Pytorchで自作ゲームにDQNのAIを組み込もう!(前編)
IT技術
はじめに
Unityなどで対戦ゲームを作っている時に、「対戦相手のAIをどうするか」ということで頭を悩ませた経験はありませんか?
ゲームのAIは「⚪︎⚪︎なら××する」というような条件分岐の処理をたくさん書いたり、ゲームの状況を表す変数を用意し、その変数に従って行動を分岐させるというようにして実装することが多いと思いますが、複雑な処理を実装するのは大変です。また、ゲームの種類によっては状況の分岐が膨大になり過ぎて書き切れなかったり、そもそもどのような処理を書くべきなのかが分からなくなってしまうこともあります。
そうなると「AIの行動パターン作成を自動化できないか?」という発想が出てきます。人間の力では難しいことは機械に任せてしまおうというわけです。
しかしそんなことは本当に可能なのかと思われる方もいらっしゃるでしょう。結論から言えば可能です。
では何を使えば可能になるのでしょうか?それこそが機械学習の一分野、強化学習です。
強化学習とは何か
Wikipediaによると、強化学習の定義は以下の通りとなっております。
ある環境内における知的エージェントが、現在の状態を観測し、得られる収益(累積報酬)を最大化するために、どのような行動をとるべきかを決定する機械学習の一分野である。強化学習は、教師あり学習、教師なし学習と並んで、3つの基本的な機械学習パラダイムの一つである。
引用元:強化学習 - Wikipedia
つまり強化学習とは、行動主体(ゲームのAI)が環境(ゲームの状況)を把握した上で最適な行動は何かを決定するための技術です。
今回のケースは「AIの行動パターンの作成を自動化したい」ということですから、まさにぴったりな技術と言えるでしょう。
次に、この強化学習を使ってAIの行動パターン作成を自動化するための準備を行っていきます。
ゲームの作成
まずはゲームを開発しましょう。AIを作るといっても、学習するための環境(=ゲーム)がなければ始まりません。
今回作成したゲームの概要は以下です。なお、ゲームの作成にはpygameを使いました。
- このゲームは対戦型のシューティングゲームです。自分のエネルギー(1Pは画面下部、2Pは画面上部に表示)がなくなったら負けになります。
- このゲームでは、威力小(青色)・威力中(緑色)・威力大(赤色)の3種類の弾を撃つことができます。自分の発射した弾が相手に到達すると、その弾の威力に応じて相手のエネルギーが減ります。ただし、弾を撃つと自分のエネルギーも威力に応じて減ってしまいます。
- 自分の発射した弾が相手の弾にぶつかると、ぶつかった弾の威力に応じて弾が弱体化します。威力が同じ弾がぶつかった時は対消滅し、威力が異なる時は威力が高い弾が残り、消えた方の弾の威力に応じて弱体化します。
- エネルギーは時間経過で回復します。また、残りエネルギーに応じてバーの色は多い方から順に緑、黄色、赤色に変化します。また、無敵状態中は残りエネルギーに関わらずバーの色は青色になります。
- 画面上を移動するUFOに自分の弾を当てると、一定時間無敵状態になります。また、同様にエイリアンに自分の弾を当てると、当てた数に応じて自分の弾の威力が上がります。
実際のゲーム画面は以下のようになります。ゲーム本体のコードは主にGame.pyにて実装しております。
それでは、今回のゲーム作成において、pygameによるゲーム開発の特徴的な部分と工夫した点について説明していきます。pygameはUnityなどのゲームエンジンと比べると抽象度の高い処理をあまり提供していないため、その分の処理を自分で実装する必要があります。そういった違いにも注目していただければと思います。
① ゲームループについて
pygameでは、メインループを以下のように自分で書く必要があります。
1def main():
2 start() #初期化処理
3 while True:
4 for event in pygame.event.get():
5 if event.type == pygame.QUIT:
6 pygame.quit()
7 sys.exit()
8 pygame.display.update() #画面の更新
9 update() #ゲーム本体の処理
上記のソースコードで注目して欲しいのは3行目のwhile True:
と5行目のif文です。
まずはwhile True:
から説明します。
どんなプログラムにも共通することですが、プログラム内の処理が全て完了するとそのプログラムは基本的にすぐ終了します。このゲームプログラムも例外ではありません。しかし画面が一瞬だけ表示されてすぐ消えてしまってはゲームになりません。そのため、ゲームを表示し続けられるように何の操作をしていなくてもゲーム中は「処理が続いている」という形にする必要があります。
そこで登場するのがwhile True:
の無限ループで、これにより処理中の状態を維持しています。また同時に、無限ループの中で画面やゲーム本体の処理を実行することでプレイヤーの操作やゲームの状態の変化がすぐ反映されるようになります。
5〜7行目のif文はゲームを終了できるようにするためのものです。このプログラムは先ほど説明したwhile True:
の無限ループによって処理中の状態を維持しているのですが、while True:
はwhile文の処理継続条件を常に満たすため、そのままでは「×」ボタンのクリックなど通常の方法でのプログラム終了ができなくなってしまいます。処理中の状態を維持したいとはいっても、終了させられないのでは困ってしまいます。そこで5行目のif文によって、ユーザーが「×」ボタンをクリックするなどプログラム終了の命令が出た時はプログラムを終了できるようにしています。
②フレームレートへの対応
このゲームはPythonと必要なモジュールを導入していればどのようなPCでも遊ぶことができますが、PCと一言に言っても事務作業用の安価なものからハイスペックなゲーミングPCまで性能には幅があり、同じゲームをプレイしていてもフレームレート(1秒あたりの画面の更新回数≒メインループの実行回数)には差があります。その上、同じPCでも他に実行しているプログラムなど条件の違いによってもフレームレートには差が発生します。
そのため、フレームレートを無視して「メインループの関数が実行される度に右に10ピクセルずつ移動」というような処理をしてしまうと、PCの性能や状態によってゲームのスピード感が全く異なるものになってしまい、ゲームとして成り立たなくなってしまいます。
pygameではpygame.time.Clock
クラスのtick
関数でフレームレートの対応を行っています。tick
関数の引数にフレームレート数を指定することでフレームレートをコントロールでき、例えばtick(30)
とするとフレームレートが30を超えそうになった時に処理を遅らせることで最大フレームレートが30になるように調整します。
③画面の更新
メインループと同様に、画面の更新についても、自分で書く必要があります。以下のソースコード例をご覧ください。
1import pygame
2Width = 800
3Height = 600
4screen = pygame.display.set_mode((Width, Height)) #描画用の画面を取得
5def update()
6 global screen
7 clock = pygame.time.Clock()
8 clock.tick(60)
9 screen.fill((0,0,0,0)) # 画面全体を黒く塗りつぶす
10 # 以下、画像表示などの処理が入る
着目して欲しいのはscreen.fill((0,0,0,0))
の部分です。
Unityなどと異なり、pygameは前のフレームで描画された内容は基本そのまま残ります。もしキャラクターの画像を前のフレームから移動させていた場合、塗りつぶしの処理を明示的に呼び出さないと絵の具を指で伸ばした時のように移動した跡が残ってしまいます。
screen.fill((0,0,0,0))
の処理を呼び出すことで、前のフレームの内容を一旦全て消し、移動した跡などが出ないようにしています。
④円の衝突判定
このゲームの弾同士や弾とUFO・エイリアンの衝突判定では、衝突判定を行う2つの物体間の距離を求めた上で2つの物体を円とみなし、物体間の距離が半径の合計以下になっていれば「衝突した」と判定します。2つの物体間の距離を求める公式は以下に示す通りです。
ソースコード上では以下に示すgetCollision
関数で判定しています。
1def getCollision(x1:float, x2:float, y1:float, y2:float, radius1:float, radius2:float):
2 if (x1 - x2) ** 2 + (y1 - y2) ** 2 <= (radius1 + radius2) ** 2:
3 return True
4 else:
5 return False
皆さんの中には、上記のソースコードについて「なぜ公式をそのまま使わず、両辺を2乗した形に変形して使用しているの?」という疑問を持った方もいらっしゃるかもしれません。
平方根の計算は複雑な処理で普通の足し算や掛け算より時間がかかります。そのためゲームプログラミングでは、可能なら平方根を求めずに処理することで高速化するのが定石となっています。衝突判定では「物体間の距離が半径の合計以下か」さえ分かっていればその後の処理は進められるため、2つの物体間の正確な距離は不要です。両辺を2乗してしまっても問題はありません。
また、不等号を含む式の両辺を2乗する場合、本来なら両辺のそれぞれで正負をみて不等号の向きを変える必要があるかの確認が必要なのですが、距離と半径はともに正の数であると分かっているので、今回のケースにおいては確認せずにそのまま処理できます。
⑤ゲーム画面をAIに送信する
AIがゲームをプレイできるようにするには、ゲームの状態をAIに渡す必要があります。
ゲームの状態の渡し方には色々ありますが、
- ターゲットになる物体の種類を増やしたり、逆に減らしたりしてもAI部のソースコードに変更を加える必要がない。
- 今回作成したAIのソースコードを他のゲームに移植させる時に変更が少なくて済む。
の以上2点を鑑みて、ゲームの画面を直接入力としてAIに渡すことにしました。
pygameではpygame.surfarray.array3d()
関数を使うことでゲーム画面の画像データを整数の配列の形で取得することができます。
AIの作成
ゲームができたので、次はいよいよAIの作成です。
強化学習にも様々な手法が存在しますが、今回はDQNを採用し、DQNの実装にはPytorchを使用しました。
このセクションでも「ゲームの作成」と同様、DQNやPytorchの実装の特徴的な部分を説明していきます。
① DQNについて
今回使用する強化学習の手法であるDQNについて説明します。DQNとは、Deep Q Networkの頭文字を取ったもので、強化学習の一手法である「Q学習」にディープラーニングを組み合わせた手法です。
Q学習にディープラーニングを組み合わせることのメリットは、複雑な環境にも対応可能なことです。Q学習は手法の特性上、環境において起こりうる状態と実行可能な行動全てについてのデータを持っておく必要があります。ごく単純な環境ならこれでも問題ありませんが、状態と行動の組み合わせが膨大になってくると全てを網羅するのが困難になります。
例えば今回作成したゲームの場合、プレイヤーが実行可能な行動は移動が右に移動・左に移動・移動しないの3通り、弾の発射が威力小・中・大の3種類の弾をそれぞれ撃つ・撃たないの2通りずつあり、2種類同時撃ち・3種類同時撃ちの両方が可能なため、3×2×2×2 = 24 通りの行動があります。しかもその上、自分と相手の残りエネルギー、無敵状態か否か、さらには画面上の弾やUFO、エイリアンの位置と個数について、ありうる組み合わせが無数に考えられるため、状態と行動全てについてデータを持っておくのはほぼ不可能だといえます。
そこで登場するのがディープラーニングです。状態と行動全てについてのデータを保持する代わりに、ディープラーニングによってそれらを近似することによって複雑な問題でもQ学習を適用可能になります。
DQNでは状態と行動全てについてのデータをディープラーニングでただ近似するだけでなく、以下のような工夫によって性能向上を図っています。
1. 経験再生 (Experience Replay)
Q学習では、「現在の状態」「状態から判断して実際に取った行動」「行動した結果の得点(報酬)」「行動後の状態」の4組のデータを利用します。この4組のデータを収集して学習を進めていくのですが、収集したデータから順に学習させようとすると、ゲーム開始時のデータ、ゲーム開始から1秒後のデータ、ゲーム開始から2秒後のデータ、...というように、連続した強い相関のあるデータを学習することになり、データが偏ることになります。データが偏ると、学習が安定しないなどの問題が生じてしまいます。
そこで登場するのが経験再生です。経験再生では、得られたデータをそのまま使わずに一旦保存しておき、学習する時は保存したデータからランダムに何個か取り出して学習します。こうすることで学習に使うデータの偏りが少なくなり、学習が安定しやすくなります。
実装の際はPython標準ライブラリのcollections.deque
を使い、古いデータが自動で削除されるようにした上で学習用のデータを一時保存します。
2. ターゲットネットワーク
Q学習では、行動価値関数(ある状態である行動を取った時の報酬の期待値を求める関数)を更新することで学習を進めます。しかし、行動価値関数の更新は行動価値関数そのもので計算した値を使用します。そのため、更新中のニューラルネットワークを使ってニューラルネットワーク更新用の計算を行う形になり、これも学習の不安定化の原因となります。
そこで、ニューラルネットワーク更新用の計算を行うために別途同じ構造のニューラルネットワークをもう1個用意します。これをターゲットネットワークと言います。実際に学習するネットワークを更新する時はターゲットネットワークに行動価値関数を計算してもらい、その値で更新します。Q学習を進めるため、ターゲットネットワークは一定の間隔で実際に学習するネットワークの重みに更新します。
② 入力データについて
「ゲームの作成」の「⑤ゲーム画面をAIに送信する」で説明したように、今回はゲームの画面を直接入力としてAIに渡しますが、取得したゲーム画面をそのままAIに送信しているわけではなく、以下のような加工を行なってから入力しています。
1. 画像のリサイズ
ゲーム画面は幅800ピクセル、高さ600ピクセルですが、このまま入力として使用するには大きすぎます。入力する画像サイズが大きすぎると学習や推論時に時間がかかってしまいますし、経験再生のために保存するデータのサイズも増えてメモリを圧迫してしまいます。そのため、ゲームの画面を入力として渡す時は縮小するのが一般的です。今回はOpenCVというライブラリのresize
関数を使用して幅100ピクセル、高さ75ピクセルに縮小していますがresize
には2つ注意点があります。どちらも従わなかった場合はエラーを出して停止するのですが、エラーメッセージが分かりづらいので気をつけましょう。
① 入力するデータ型
resize
関数で縮小する画像データの型はnumpy.uint8
である必要があります。これ以外の型ではエラーになるようです。resize
関数を使用する時は入力するデータの型を確認し、もしnumpy.uint8
でない場合はastype
関数などでnumpy.uint8
に変換してから使用するようにしましょう。
② 画像データ配列の形状
resize
関数で縮小する画像データの配列の形状は(幅, 高さ, 画像のチャンネル数)でないといけません。幅と高さは逆になっていても良いですが、画像のチャンネル数が最初に来ていたりするとエラーになります。画像データの配列の形状が(幅, 高さ, 画像のチャンネル数)になっていない時はリサイズする前にreshape
関数などで整形しておきましょう。
2. 画像の正規化
ディープラーニングに限らず、機械学習のAIは桁数の多い項目の影響を受けやすいです。今回はゲーム画面以外のデータを入力に使用していませんが、他のデータも入力に入れる場合にデータのスケールが異なっていると、桁数が多い項目が過大評価される(あるいは桁数の少ない項目が過小評価される)原因になります。
そこで必要になってくるのが正規化です。正規化の処理を行うことによって数値の最大値が1、最小値が0に統一され、上記のような問題が解消されます。正規化の処理は以下のように実装しています。
1g_min = state.min()
2g_max = state.max()
3# 正規化時のゼロ除算対策
4# (g_max - g_min) が0の時(最大値と最小値が同じ時)はg_minが0なら0、そうでないなら1にする
5if (g_max - g_min) == 0:
6 state[:] = 0 if g_min != 0 else 1
7else:
8 state = (state - g_min) / (g_max - g_min)
state
はゲーム画面の画像データのnumpy配列です。データ内の最小値・最大値は(numpy配列).min()
・(numpy配列).max()
でそれぞれ求められます。引数を指定しない場合(numpy配列).min()
や(numpy配列).max()
の返り値は1個の数値になりますが、numpy配列にはブロードキャストという機能があり、1個の数値変数で計算する時と同様の書き方で配列の各要素に同じ計算を一斉に行うことができます。これにより計算の記述がシンプルになるだけでなく、計算処理の高速化も期待できます。
基本的には正規化の公式である「(データ - データの最小値) ÷ (データの最大値 - データの最小値)」で計算していますが、配列内の数値が全て同じでデータの最大値 - データの最小値が0になってしまう時(画像が全て1色の時)ではゼロ除算になってしまいます。そのため、データの最大値 - データの最小値が0の時は通常の正規化の計算を行う代わりにデータの最小値をチェックして、最小値が0なら0、そうでないなら1で全て埋めるようにしています。
3. 画像の連続フレーム化
DQNをはじめとして、強化学習のモデルでは最適な行動を決めるために必要な全ての情報が入力である「現在の状態」から取得できる必要があります。しかし、ゲーム画面1フレームからだけでは必要な全ての情報が取得できません。
説明のために、ここで実際のゲーム画面を以下に再掲します。
DQNのモデルはゲームのルールについて何の知識も持っていないため、ゲーム画面1フレームから弾やUFO、エイリアンの位置は分かりますがこれらの進行方向を判断することができません。そこでゲーム画面の連続する数フレームをまとめて入力として使用することにより、画面内の物体の進行方向を入力から判断できるようにします。今回は連続4フレームをまとめて入力にしています。
ゲーム画面の連続フレームを入力にするため、こちらでもPython標準ライブラリのcollections.deque
を使い、常に直前の連続4フレームを取得できるようにします。
ここで一つテクニックを紹介します。実は、collections.deque
はnumpy.array()
関数の引数に直接入れることにより、形状そのままでnumpy配列に変換することができます。numpy配列に変換する関係上deque内の各要素の型や配列の形状(配列を格納している場合)は全て一致させるかnumpy.array()
関数の引数dtypeにobjectを指定する(全て一致させることが不可能な場合)必要がありますが、便利なテクニックなので覚えておきましょう。
ここまでで、AIに入力するデータの準備ができました。次回はいよいよAIを作成します。使用するモデルや、学習する方法について解説していきますのでお楽しみに。
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
TypeScript、Unity、Python、Goが得意なエンジニア。 最近はTypeScript+next.jsでの開発が多いです。