DeepAutoEncoderを利用した画像、時系列データの異常検知入門(keras編)

この記事はアイスタイルアドベントカレンダー の13日目の記事です。

はじめに

こんにちは。
7月からアイスタイルに中途入社しました、R&Dのkawakamikです。
毎日@cosmeのデータをこねくり回してあーでもないこーでもないと業務を進めています。

こねくり回しているデータは膨大なので、極稀に変則的なデータも存在します。
こういったデータは異常値として自動的に検出できればいいなあ…

という流れで、DeepAutoEncoderを利用した異常データ検出入門について書きます。

DeepAutoEncoderとは

AutoEncoderの層を増やしたニューラルネットワークです。
入力、出力を同一のデータ、中間層のサイズを入力層、出力層より小さいサイズにすることで、
DeepAutoEncoderは情報量を圧縮する方法を学習します。
この技術を利用し、画像の異常検知と時系列データの異常検知を行います。

…と文字に起こしてもよくわからないと思うので、別の例に例えて説明します。

あるところに、猫の絵を専門に書く画家がいました。
この画家が今回のモデル(DeepAutoEncoder)です。

この画家は猫が大好きなため、子供の頃は猫の観察ばかりしていました。

そのおかげで、大人になった画家は猫と猫の絵の違いがわからないぐらい精巧な猫の絵を書くことが可能です。

すごい才能を持った画家ですが、大きな問題があります。
それは、猫以外の絵がまともに書けないという問題です。
たとえば、犬を見せて書かせると、犬と猫を混ぜたような絵だったり、何を書いたかわからない絵ができあがります。

なので、この画家に画像を見せて、画家が書いた絵を見ることで、見せた画像が猫画像かそれ以外の画像かを判別することが可能になります。
この画家を現実問題に置き換えると、

  • 正常画像を利用して、学習を行う
  • 学習したモデルに画像を入力する
  • モデルから出力された画像を確認し、入力された画像が正常か異常かを判断する

という流れができます。

開発環境

コードはGoogle Colaboratory上で作成しました。
モデル作成に利用するtensorflowは1.15.0を利用しています。

画像データの異常検知

まずは、画像データの異常検知を行います。
利用する画像データはMNISTと呼ばれるデータセットで、手書き数字画像60000枚とテスト画像10000枚を集めた画像データセットとなっています。
以下にデータセットの一例を示します。
左から、7、2,1,0,4の手書き画像となっています。

モデルの学習

このデータセットから、数字の1が書かれた画像を正常画像、それ以外の数字が書かれた画像を異常画像として設定し、異常検知を行います。
以下コードでデータセット取得、学習を実行します。

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

# DeepAutoEncoderモデルの構築
input_layer = Input(shape=(784,))
encoded = Dense(300, activation="relu")(input_layer)
encoded = Dense(100, activation="relu")(encoded)

decoded = Dense(300, activation="relu")(encoded)
output_layer = Dense(784, activation="sigmoid")(decoded)

model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='binary_crossentropy', optimizer="adam") 

# MNISTデータの読み込み
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 正規化
x_train = np.array([ (x.astype('float32') / 255.0).flatten() for x in x_train ])
x_test = np.array([ (x.astype('float32') / 255.0).flatten() for x in x_test ])

# 1と1以外のデータに分ける
x_train_normal = np.array([ x for x, y in zip(x_train, y_train) if y == 1 ])
x_test_normal = np.array([ x for x, y in zip(x_test, y_test) if y == 1 ])
x_test_anomaly = np.array([ x for x, y in zip(x_test, y_test) if y != 1 ])

# 学習
epochs = 100
batch_size = 32
validation_split = 0.25
model.fit(x_train_normal, x_train_normal, epochs=epochs, batch_size=batch_size, validation_split=validation_split)

モデルの出力確認

数字の1が書かれた画像のみで学習が完了しました。
学習したモデルを利用して、数字の1が書かれた画像を入力したら、どんな出力が出るのかを確認します。

# 出力画像を作成
predict_normal_data = model.predict(x_test_normal)

# 出力された画像と元の画像を並べて表示
for i in range(3):
  plt.figure()
  plt.subplot(1, 2, 1)
  plt.imshow(x_test_normal[i].reshape(28, 28), cmap='Greys')
  plt.axis('off')
  plt.title("input image")
  plt.subplot(1, 2, 2)
  plt.imshow(predict_normal_data[i].reshape(28, 28), cmap='Greys')
  plt.axis('off')
  plt.title("output image")

左がモデルに入力した画像、右がそれに対するモデルの出力画像です。
どうやら、手書きの1をうまく再現して出力できていそうです。
では、数字の1以外が書かれた画像を入力するとどうなるのかも見てみましょう。

# 出力画像を作成
predict_anomaly_data = model.predict(x_test_anomaly)

# 出力された画像と元の画像を並べて表示
for i in range(3):
  plt.figure()
  plt.subplot(1, 2, 1)
  plt.imshow(x_test_anomaly[i].reshape(28, 28), cmap='Greys')
  plt.axis('off')
  plt.title("input image")
  plt.subplot(1, 2, 2)
  plt.imshow(predict_anomaly_data[i].reshape(28, 28), cmap='Greys')
  plt.axis('off')
  plt.title("output image")

出力された画像はところどころ欠損が見られます。
このような出力がされた場合、異常と判断して良さそうです。

正常画像と異常画像の判別手法

目で正常画像と異常画像が確認できたので、次はコードで正常画像と異常画像を判別します。
判別指標として、入力画像と出力画像の平均二乗誤差(以下MSEとする)を利用します。
この値が高ければ異常画像、低ければ正常画像として判別します。

以下のコードで全画像のMSEをヒストグラムで表します。

normal_data_mse = [ mean_squared_error(x, y) for x, y in zip(predict_normal_data, x_test_normal) ]
anomaly_data_mse = [ mean_squared_error(x, y) for x, y in zip(predict_anomaly_data, x_test_anomaly) ]

plt.figure(figsize=(12, 8))
plt.hist(normal_data_mse, bins=40, color="blue", alpha=0.5, label="normal")
plt.hist(anomaly_data_mse, bins=100, color="red", alpha=0.5, label="anomaly")
plt.title("MSE hist")
plt.xlabel("MSE")
plt.ylabel("freq")
plt.legend(fontsize=12)

青いヒストグラムは正常画像のMSE、赤いヒストグラムが異常画像のMSEを示しています。
おおよそ正常画像と異常画像が分けられそうですね。

時系列データの異常検知

次は時系列データの異常検知を行います。
時系列データも画像と同じ用に、時系列データを入力したら、入力データを再現できたか、そうでないかで正常か異常かを判断します。

時系列データの異常検知では趣向を変えて、
上司「大量にデータを貯めたから、AIでちゃちゃっと異常データを見つけてくれ」
という上司の無茶振り(実体験)に対応することにしましょう。

時系列データの用意

なにはともあれ、まずは時系列データを用意します。
今回はダミーの時系列データを作成します。
以下に時系列データの作成とデータの可視化を行うコードを示しますが、コードを読まずに可視化されたデータから異常データを当ててみるのに挑戦するのもいいでしょう。

# 時系列データ長を100とする
input_data_length = 100

# データ長100のsin波にノイズ大を足したデータを10000個作成
small_noise_sin = np.array([ np.sin( np.linspace(0, np.pi*2, input_data_length) ) + np.random.randn(input_data_length) * 0.5 for _ in range(10000) ])

# データ長100のsin波にノイズ小を足したデータを100個作成
big_noise_sin = np.array([ np.sin( np.linspace(0, np.pi*2, input_data_length) ) + np.random.randn(input_data_length) * 0.1 for _ in range(100) ])

# データ長100のcos波にノイズ小を足したデータを100個作成
small_noise_cos = np.array([ np.cos( np.linspace(0, np.pi*2, input_data_length) ) * 3 + np.random.randn(input_data_length) * 0.5 for _ in range(100) ])

# データ結合
input_data = np.vstack([small_noise_sin, big_noise_sin, small_noise_cos])

# データをシャッフル
np.random.seed(seed=42)
np.random.shuffle(input_data)

# 全時系列データを可視化
plt.figure(figsize=(12, 8))
_ = [ plt.plot(np.arange(0, input_data_length), x) for x in input_data ]

今回は時系列データ長100のデータを10200個用意しました。
10200個描画してみましたが、正常データと異常データを見分けることはできましたか?。
では、DeepAutoEncoderで正常データ、異常データを分けてみましょう。

問題設定

今回のケースのように、正常データと異常データの判断がつかない時、まず取るべきアプローチは、
手元にあるデータすべてを正常データとみなしてモデルに入力してみることです。
そうすると、モデルは全データの大まかな特徴を掴んでくれるようになります。

その結果、以下のような大雑把な定義が設定できるようになります。

  • 大まかな特徴が多いデータは多数派のデータ=正常データ
  • 大まかな特徴が少ないデータは少数派のデータ=異常データ

この定義に基づいて正常データと異常データ分けを進めていきましょう。

モデルの学習

まずは、以下コードで学習を実行します。

from sklearn import preprocessing

# DeepAutoEncoderモデルの構築
input_layer = Input(shape=(100,))
encoded = Dense(50, activation="relu")(input_layer)
encoded = Dense(25, activation="relu")(encoded)

decoded = Dense(50, activation="relu")(encoded)
output_layer = Dense(100, activation="sigmoid")(decoded)

model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='mse', optimizer="adam")

# 正規化
ss = preprocessing.StandardScaler()
input_data_scaler = ss.fit_transform(input_data)

# 学習
epochs = 100
batch_size = 32
validation_split = 0.25
model.fit(input_data_scaler, input_data_scaler, epochs=epochs, batch_size=batch_size, validation_split=validation_split)

MSEの確認

学習したモデルを利用し、今回はまず全データのMSEを確認してみます。
以下に全データMSEのヒストグラムを示します。

# モデルに時系列データを入力し、モデルが推定した時系列データを出力させる
predict_data = model.predict(input_data_scaler)

# MSEを計算
all_data_mse = [ mean_squared_error(x, y) for x, y in zip(predict, input_data_scaler) ]

# MSEをヒストグラムで表示
plt.figure(figsize=(12, 8))
plt.hist(all_data_mse, bins=100, color="blue", alpha=0.5)
plt.title("MSE hist")
plt.xlabel("MSE")
plt.ylabel("freq")
plt.legend(fontsize=12)

MSEのヒストグラムはどうやらMSE0~2とMSE12~14の2つのグループに分けられそうです。

MSEが高いデータを描画してみる

MSEが高めな、MSE12~14のグループは異常データと言えるかもしれません。
対象のデータがどのようなデータなのか描画して確認しましょう。

# MSEが10以上の時系列データのindexを取得する
anomaly_index = np.where(np.array(all_data_mse) >= 10)[0]

# MSE10未満の時系列データを青で描画、MSE10以上の時系列データを赤で描画
plt.figure(figsize=(12, 8))
_ = [ plt.plot(np.arange(0, input_data_length), x, color="blue") for x in input_data ]
_ = [ plt.plot(np.arange(0, input_data_length), x, color="red") for x in input_data[anomaly_index] ]

青がMSE10未満の時系列データ、赤がMSE10以上の時系列データです。
見た目からして異常なデータが抽出されているのが確認できますね。
異常データが抽出できました。めでたしめでたし。

…ちょっとまってください。
見ただけでわかるデータが異常ならば、異常検知を使うまでもありません。
人間が見ただけでわからない異常データが見つかれば万々歳です。
なので、もう少しデータを深堀りしてみましょう。

MSEが低いデータを描画してみる

MSEのヒストグラムを再度確認します。

すると、ヒストグラムの0付近にデータが集まっている事が確認できます。
このデータを取り出して描画してみましょう。

# MSEが0.1以下の時系列データのindexを取得する
anomaly_index = np.where(np.array(all_data_mse) <= 0.1)[0]

# MSE0.1より上の時系列データを青で描画、MSE0.1以下の時系列データを赤で描画
plt.figure(figsize=(12, 8))
_ = [ plt.plot(np.arange(0, input_data_length), x, color="blue") for x in input_data ]
_ = [ plt.plot(np.arange(0, input_data_length), x, color="red") for x in input_data[anomaly_index] ]

MSE0.1以下のデータを赤く描画、それ以外のデータを青く描画してみました。
すると、明らかにノイズが少ないデータが見えてきました。
これは人間が見ただけではわからない異常と言えますね。

あとがき

DeepAutoEncoderを利用した異常検知の基本を書きました。
異常検知に触れたことが無い方が大まかな流れをつかめたら幸いです。

明日はsugatoさんの「matplotlibを利用して、20年間のクチコミデータを動的に可視化する」です。

2019年7月中途入社。 アイスタイルのレコメンドエンジン開発に関わっています。 趣味は猫とお酒とIoT