DQN(Deep Q Network)で自作パズルを解く

はじめに

こんにちは。データ分析などを担当しているtanitです。
この記事はistyleアドベントカレンダーの19日目の記事です。

概要

今回はDQN(Deep Q Network)で自作のパズルを解けるようにしたいと思います。

まず、強化学習について簡単に説明します。
強化学習では、エージェントが、ある環境に応じて、行動を起こし、その結果に対して報酬が与えられます。
一連の行動をして、取得した報酬の合計を行動価値と呼び、その行動価値を最大化するように学習します。

強化学習は、大まかに分けて、モンテカルロ法とTD法があり、今回はTD法の一つであるQ学習というアルゴリズムを選択します。

また、深層学習(ディープラーニング)を使って学習するので、DQN(Deep Q Network)と呼びます。

環境

google colaboratoryを使用します。
google colaboratoryとは、Googleが提供している
Jupyter Notebookをベースにしたpython環境です。

使用するメリット
・GPU/TPU(ディープラニングがGPUよりも高速)が無料で使える。
・numpy、Matplotlib、Pandas、Keras、pillowなどが予めインストール済。

使用手順は下記となります。

下記URL にアクセス
https://colab.research.google.com/notebooks/welcome.ipynb?hl=ja

「ファイル」をクリックし、「PYTHON 3 の新しいノートブック」を選択

コードを入力し、「Ctrl + Enter」 で実行できる。
([Ctrl + Shift]では、実行して次のセルに飛ぶ。)

機械学習を使用する場合には、下記手順を実施。

「ランタイム」 → 「ランタイムのタイプを変更」から
「ハードウェアアクセラレータ」で「TPU」を選択して保存。

自作パズル

下記がパズルのソースコードになります。
(少々長くなりますが、ご勘弁を!)

from PIL import Image
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt

class Puzzle():
    def __init__(self, image, side_num):
        self.width, self.height = image.size
        self.side_num = side_num

        self.piece_w = self.width / side_num
        self.piece_h = self.height / side_num

        self.piece_images = []
        self.piece_numbers = []
        self.start_piece_numbers = []

        last_index = side_num ** 2 - 1
        self.blank_number = last_index

        for piece_index in range(last_index):
            w_pos = (piece_index % side_num) * self.piece_w
            h_pos = int(piece_index / side_num) * self.piece_h

            piece_image = image.crop((int(w_pos), int(h_pos), int(w_pos + self.piece_w), int(h_pos + self.piece_h)))

            self.piece_images.append(piece_image)
            self.piece_numbers.append(piece_index)

        blank_image = Image.new('RGB', (int(self.piece_w), int(self.piece_h)))
        self.piece_images.append(blank_image)
        self.piece_numbers.append(self.blank_number)

        random_list = list(range(1000))

        while(self.is_completed()):

          random.shuffle(random_list)

          for i in random_list:
              i %= 4
              if i == 0:
                  self.action('right')
              elif i == 1:
                  self.action('left')
              elif i == 2:
                  self.action('up')            
              elif i == 3:
                  self.action('down')
          for _ in range(side_num):
              self.action('right')
              self.action('down')


        self.start_piece_numbers = self.piece_numbers[:]

    def reset(self):
        self.piece_numbers = self.start_piece_numbers[:]

    def show(self):
        back_image = Image.new('RGB', (int(self.width), int(self.height)))

        for piece_index in range(self.side_num ** 2):
            w_pos = (piece_index % self.side_num) * self.piece_w
            h_pos = int(piece_index / self.side_num) * self.piece_h            

            back_image.paste(self.piece_images[self.piece_numbers[piece_index]], (int(w_pos), int(h_pos)))

        im_list = np.asarray(back_image)

        plt.xticks(color="None")
        plt.yticks(color="None")
        plt.tick_params(length=0)
        plt.imshow(im_list)
        plt.show()

    def right_action_state(self):
        blank_index = self.piece_numbers.index(self.blank_number)

        state = []
        if (blank_index % self.side_num) != self.side_num - 1:
            state = self.piece_numbers[:]
            state[blank_index], state[blank_index + 1] = state[blank_index + 1], state[blank_index]
        return state

    def left_action_state(self):
        blank_index = self.piece_numbers.index(self.blank_number)

        state = []
        if (blank_index % self.side_num) != 0: 
            state = self.piece_numbers[:]
            state[blank_index], state[blank_index - 1] = state[blank_index - 1], state[blank_index]
        return state    

    def up_action_state(self):
        blank_index = self.piece_numbers.index(self.blank_number)

        state = []
        if int(blank_index / self.side_num) != 0:
            state = self.piece_numbers[:]
            state[blank_index], state[blank_index - self.side_num] = state[blank_index - self.side_num], state[blank_index]
        return state  

    def down_action_state(self):
        blank_index = self.piece_numbers.index(self.blank_number)

        state = []
        if int(blank_index / self.side_num) != self.side_num - 1:
            state = self.piece_numbers[:]
            state[blank_index], state[blank_index + self.side_num] = state[blank_index + self.side_num], state[blank_index]
        return state  

    def action(self, direction):
        state = []
        if direction == 'right':
            state = self.right_action_state()
            if state:
                self.piece_numbers = state

        elif direction == 'left':
            state = self.left_action_state()
            if state:
                self.piece_numbers = state

        elif direction == 'up':
            state = self.up_action_state()
            if state:
                self.piece_numbers = state

        elif direction == 'down':
            state = self.down_action_state()
            if state:
                self.piece_numbers = state

        else:
             print("invalid action")

    def is_completed(self):

        if self.piece_numbers == sorted(self.piece_numbers):
            is_completed = True
        else:
            is_completed = False

        return is_completed


    def get_reward(self):

        reward = 0

        if self.is_completed():
            reward = 100000
        else:
            for i, number in enumerate(self.piece_numbers):
                if i == number:
                    reward += 0

        return reward

    def get_state(self):
        return self.piece_numbers

    def set_state(self, state):
        self.piece_numbers = state

    def get_actionable(self):
        states = []
        state = self.right_action_state()
        if state:
            states.append(state)
        state = self.left_action_state()
        if state:
            states.append(state)
        state = self.up_action_state()
        if state:
            states.append(state)
        state = self.down_action_state()
        if state:
            states.append(state)

        return states            

下記コードで、画像をアップロードできます。
パズルの素材にする画像をお好きに選択して下さい。

from google.colab import files
uploaded = files.upload()

下記コードでパズルの作成、表示をします。
変数「side_num」の値が、パズルの一辺のピースの数になります。今回は2なので、2行2列のパズルとなります。

side_num = 2
im = Image.open(画像ファイル名)
puzzle = Puzzle(im, side_num)
puzzle.show()

下記のようなパズル※が出来上がりました。

※黒いピースをその4方にあるピースのどれかと位置を交換していき、元の絵に戻すことがゴールとなります。

DQN

下記がDQNのソースコードになります。

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam, RMSprop
from collections import deque
from keras import backend as K

class DQN:
    def __init__(self, state_size):
        self.state_size = state_size 
        self.memory = deque(maxlen=100000) 
        self.epsilon = 1.0
        model = Sequential()
        model.add(Dense(128, input_shape=(2, self.state_size), activation='relu'))
        model.add(Flatten())
        model.add(Dense(128, activation='relu'))
        model.add(Dense(128, activation='relu'))
        model.add(Dense(1, activation='linear'))
        model.compile(loss="mse", optimizer=Adam(lr=0.0001))
        self.model = model

    def choice_action(self, state, actionable):
        if self.epsilon >= random.random():
            return random.choice(actionable)
        else:
            return self.choice_best_action(state, actionable)

    def choice_best_action(self, state, actionables):
        best_actions = []
        max_action_value = -999
        for actionable in actionables:
            action_value = self.model.predict(np.array([[state, actionable]]))
            if action_value > max_action_value:
                best_actions = [actionable,]
                max_action_value = action_value
            elif action_value == max_action_value:
                best_actions.append(actionable)
        return random.choice(best_actions)

    def remember_memory(self, state, next_state, reward, next_actionables, is_completed):
        self.memory.append((state, next_state, reward, next_actionables, is_completed))

    def replay_experience(self, batch_size):
        batch_size = min(batch_size, len(self.memory))
        minibatch = random.sample(self.memory, batch_size)
        X = []
        Y = []
        for i in range(batch_size):
            state, next_state, reward, next_actionables, is_completed = minibatch[i]
            input_states = [state, next_state]
            if is_completed:
                target = reward
            else:
                next_rewards = []
                for actionable in next_actionables:
                    next_rewards.append(self.model.predict(np.array([[next_state, actionable]])))
                target = reward + 0.9 * np.amax(np.array(next_rewards))
            X.append(input_states)
            Y.append(target)
        n_X = np.array(X)
        n_Y = np.array([Y]).T
        self.model.fit(n_X, n_Y, epochs=1, verbose=0)
        if self.epsilon > 0.01:
            self.epsilon *= 0.9999

下記のソースコードを実行すると学習が始まります。

dqn = DQN(side_num**2)

episodes = 20000

times = 100

for e in range(20000):
    puzzle.reset()
    state = puzzle.get_state()
    for time in range(times):
        movables = puzzle.get_actionable()
        next_state = dqn.choice_action(state, movables)
        is_completed = puzzle.is_completed()
        puzzle.set_state(next_state)
        next_actionables = puzzle.get_actionable()
        dqn.remember_memory(state, next_state, puzzle.get_reward(), next_actionables, is_completed)
        if is_completed or time == (times - 1):
            if e % 1000 == 0:
                print("episode: {}, time {}".format(e, time))
            break
        state = next_state

    dqn.replay_experience(32)

検証

下記のソースコードで検証します。

puzzle.reset()
puzzle.show()
def test_sequence():
  movables = puzzle.get_actionable()
  state = puzzle.get_state()
  next_state = dqn.choice_best_action(state, movables)
  puzzle.set_state(next_state)
  puzzle.show()

for _ in range(100):
  test_sequence()
  if puzzle.is_completed():
    break

下記が、結果となります。

          ↓

          ↓

          ↓

          ↓

最短経路でパズルが解けていることが確認できました。

参考にさせていただいた記事

DQNで自作迷路を解く
https://qiita.com/cvusk/items/e4f5862574c25649377a

最後に

・3行3列のパズルにも挑戦しましたが、上手くできませんでした。学習時間やハイパーパラメータなどを変えて再度、挑戦してみたいです。
・TD法を使いましたが、ランダムに試行してその中で一番良いものを選ぶモンテカルロ法の方が今回の課題に適していると普通に考えて思いました。
・ディープラーニングの部分で、隠れ層を一つ減らしただけで、値が収束しなくなったので、チューニングの大事さを感じることができました。
・強化学習を業務に役立てる方法も考えていきたいです。