GANで本物のように精巧な画像生成モデルを作ってみた【Pytorch】
IT技術
GANとは?
あなたは下の画像が「機械が生成した実在しない人の顔写真」か「実在する人の顔写真」かを見分けられますか?
近年AI技術が発達し、「機械が考える」レベルが現実的になってきています。
また、CNN(Convolution Neural Network)の登場により、画像分野はさらに目覚ましい発展を遂げ、未知の画像を生成する生成モデルなるものが提案されました。
その中でも、ここ最近では「GAN(Generative Adversarial Networks)」とよばれる技術が一躍注目を浴びています。
実は、先程の画像も、GANが生成した画像であり、現実に存在する本物の人ではありません。
ただ、機械が作ったとは思えないほど精巧さで、判別は難しいです。
さて今回は、そんなGANの仕組みの解説と実践を行ってみたいと思います!
GANのネットワーク構造
GANのネットワーク構造は2つに分かれます。
- 実際に画像を生成するGenerator
- 画像が与えられたとき、それが機械から生成された画像かどうかを判定する識別器であるDiscriminator
(※ニューラルネットは識別問題を解くのが得意です)
Generator:画像生成器
「Generator」のネットワーク構造は、以下の画像のようになっています。
ノイズから顔画像を生成します。
Discriminator:本物or偽物かの識別機
「Discriminator」のネットワーク構造は、このようになっています。
画像を入力として、その画像が「本物」か「偽物」かどうかを判定します。
ノイズについて
ノイズは、いわゆる正規分布からサンプリングされます。
この分布の次元は、一般に画像空間よりも低次元です。
(※画像空間は、縦×横のピクセル数だけ次元が存在します)
2次元で考えた場合
例えば、2次元で考えたとしましょう。
そうしたとき上の図のように、「二次元上の点」が「ひとつの画像」を表すように画像空間を埋め込みます。
それぞれの青丸の点は、正規分布からサンプリングされた点です。
このように、「Generator」の入力部のノイズは二次元上の点を表しています。(今回の場合)
そして出力は、その二次元上の点に埋め込まれた画像空間上の画像を生成することになります。
本来のノイズは、もっと高次元です。
さすがに、顔写真を二次元空間に落とし込むのは無理があります。
GANの損失関数
それでは、まずは学習の話です。
「Discriminator」の学習では、本物を本物、偽物を偽物と見分ける能力が必要です。
「Generator」から生成された画像に対しては、偽物ラベルの「0」を出力。
データセットの画像に対しては、本物ラベルの「1」を出力するようにします。
こうすることで「Discriminator」は、本物画像と偽物画像それぞれの特徴を学習し、本物か偽物かを見分ける能力を伸ばしていきます。
特徴を学習するための損失関数
そのため。損失関数は以下のようになります。
本物の顔画像 | |
ノイズ | |
そして「Generator」の学習では、偽物画像を本物だと騙したいため、「Generator」から出力された画像を入力として「Discriminator」に入れた時の出力が「1」に近づくようにします。
偽物を本物と思わせるための損失関数
そのため、損失関数は以下のようになります。
本物の顔画像 | |
ノイズ | |
このようにして「Generator」は、本物画像と偽物画像の特徴を学習した「Discriminator」を、さらに騙すように偽物画像をより本物画像へと近づけていきます。
これが「Adversarial(敵対的)」という由来です。
GANの実験結果
では、実際にGANの学習を行ってみます!
そして、epoch 1~14までの結果を見ていきましょう!
使ったデータセットは、以下のリンクのzipフォルダを解凍した画像群です。
かなり枚数が多いので、その一部の20万枚だけ使いました。
【Celeb-aデータセット】
https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
発生させるノイズは固定で、どのように顔が作られていくかを見ることができます。
epoch = 1
epoch = 2(ぼんやりと顔の形が)
epoch = 3(水彩画みたいな顔が出来始める)
epoch = 4(サンダーバードみたいな顔が出来はじめる)
epoch = 5(あまり変化なし、大体みんな同じ顔)
epoch = 6(輪郭が綺麗に)
epoch = 7(ちらほらいい結果が)
epoch = 8(割愛)
epoch = 9
epoch = 10
epoch = 11
epoch = 12
epoch = 13(斜め顔は苦手っぽい?)
epoch = 14
結果に対する感想
顔画像の生成が出来ましたが、なにより生成データの多様性があるのが素晴らしいと思いました。
笑っている人もいれば、口を閉じている人もいるし、顔の向きも特徴付けられています。
それに肌の色なんかも、よく再現できているのではないでしょうか。
GAN は学習自体が難しいのですが、たった数epochで成果が目で見えるので、総合的にみると楽になるな感じました。
そして、学習は「データ数が命」ということに気付かされました。
実験に用いたソースコード(Pytorch)
では最後に、実験に用いたソースコードを紹介したいと思います。
まず、Pytorch が入っていない方はコチラ!
【Pytorch公式サイト】
https://pytorch.org/
こちらから、自分の環境にあったインストールコードを terminal に入力します。
ソースコード
1# -*- coding: utf-8 -*-
2
3
4import os
5import random
6import numpy as np
7import torch.nn as nn
8import torch.optim as optim
9import torch.utils.data
10import torchvision.datasets as dset
11import torchvision.transforms as transforms
12import torchvision.utils as vutils
13import matplotlib.pyplot as plt
14import copy
15from collections import OrderedDict
16# Initial_setting
17workers = 1
18batch_size=64
19nz = 100
20nch_g = 64
21nch_d = 64
22n_epoch = 10000
23lr = 0.001
24beta1 = 0.5
25outf = './result_lsgan'
26display_interval = 100
27save_fake_image_interval = 1500
28plt.rcParams['figure.figsize'] = 10, 6
29
30
31try:
32 os.makedirs(outf, exist_ok=True)
33except OSError as error:
34 print(error)
35 pass
36
37random.seed(0)
38np.random.seed(0)
39torch.manual_seed(0)
40
41
42def weights_init(m):
43 classname = m.__class__.__name__
44 if classname.find('Conv') != -1:
45 m.weight.data.normal_(0.0, 0.02)
46 m.bias.data.fill_(0)
47 elif classname.find('Linear') != -1:
48 m.weight.data.normal_(0.0, 0.02)
49 m.bias.data.fill_(0)
50 elif classname.find('BatchNorm') != -1:
51 m.weight.data.normal_(1.0, 0.02)
52 m.bias.data.fill_(0)
53
54
55class Generator(nn.Module):
56 def __init__(self, nz=100, nch_g=64, nch=3):
57 super(Generator, self).__init__()
58 self.layers = nn.ModuleDict({
59 'layer0': nn.Sequential(
60 nn.ConvTranspose2d(nz, nch_g * 8, 4, 1, 0),
61 nn.BatchNorm2d(nch_g * 8),
62 nn.ReLU()
63 ), # (100, 1, 1) -> (512, 4, 4)
64 'layer1': nn.Sequential(
65 nn.ConvTranspose2d(nch_g * 8, nch_g * 4, 4, 2, 1),
66 nn.BatchNorm2d(nch_g * 4),
67 nn.ReLU()
68 ), # (512, 4, 4) -> (256, 8, 8)
69 'layer2': nn.Sequential(
70 nn.ConvTranspose2d(nch_g * 4, nch_g * 2, 4, 2, 1),
71 nn.BatchNorm2d(nch_g * 2),
72 nn.ReLU()
73 ), # (256, 8, 8) -> (128, 16, 16)
74
75 'layer3': nn.Sequential(
76 nn.ConvTranspose2d(nch_g * 2, nch_g, 4, 2, 1),
77 nn.BatchNorm2d(nch_g),
78 nn.ReLU()
79 ), # (128, 16, 16) -> (64, 32, 32)
80 'layer4': nn.Sequential(
81 nn.ConvTranspose2d(nch_g, nch, 4, 2, 1),
82 nn.Tanh()
83 ) # (64, 32, 32) -> (3, 64, 64)
84 })
85
86 def forward(self, z):
87 for layer in self.layers.values():
88 z = layer(z)
89 return z
90
91
92class Discriminator(nn.Module):
93 def __init__(self, nch=3, nch_d=64):
94 super(Discriminator, self).__init__()
95 self.layers = nn.ModuleDict({
96 'layer0': nn.Sequential(
97 nn.Conv2d(nch, nch_d, 4, 2, 1),
98 nn.BatchNorm2d(nch_d),
99 nn.LeakyReLU(negative_slope=0.2)
100 ), # (3, 64, 64) -> (64, 32, 32)
101 'layer1': nn.Sequential(
102 nn.Conv2d(nch_d, nch_d * 2, 4, 2, 1),
103 nn.BatchNorm2d(nch_d*2),
104 nn.LeakyReLU(negative_slope=0.2)
105 ), # (64, 32, 32) -> (128, 16, 16)
106 'layer2': nn.Sequential(
107 nn.Conv2d(nch_d * 2, nch_d * 4, 4, 2, 1),
108 nn.BatchNorm2d(nch_d*4),
109 nn.LeakyReLU(negative_slope=0.2)
110 ), # (128, 16, 16) -> (256, 8, 8)
111 'layer3': nn.Sequential(
112 nn.Conv2d(nch_d * 4, nch_d * 8, 4, 2, 1),
113 nn.BatchNorm2d(nch_d*8),
114 nn.LeakyReLU(negative_slope=0.2)
115 ), # (256, 8, 8) -> (512, 4, 4)
116 'layer4':nn.Sequential( nn.Conv2d(nch_d * 8, 1, 4, 1, 0)
117 # (512, 4, 4) -> (1, 1, 1)
118 #勾配消失を防ぐためにSigmoidは使わない
119 )
120 })
121
122 def forward(self, x):
123 for layer in self.layers.values():
124 x = layer(x)
125 x = x.squeeze()
126
127 return x
128
129
130
131
132
133def main():
134 #パスには入っている画像データセットの一個上の改装を
135 #例:もし your_home/Face_Datasets/Japanese/000.jpgのような階層になっていれば
136 #root = your_home/Face_Datasetsとする
137 dataset = dset.ImageFolder(root='C:/Users/User/Downloads/img_align_celeba/',
138 transform=transforms.Compose([
139 transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)),
140 transforms.RandomHorizontalFlip(),
141 transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
142 transforms.ToTensor(),
143 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
144 ]))
145
146 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
147 shuffle=True, num_workers=int(workers))
148
149 #GPUがあるならGPUデバイスを作動
150 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
151
152 #Generatorを定義
153 netG = Generator(nz=nz, nch_g=nch_g).to(device)
154 #ネットワークパラメ-タをランダムに初期化
155 netG.apply(weights_init)
156
157 #Discriminatorを定義
158 netD = Discriminator(nch_d=nch_d).to(device)
159
160 #ネットワークパラメ-タをランダムに初期化
161 netD.apply(weights_init)
162
163 #損失関数を二乗誤差に設定
164 criterion = nn.MSELoss()
165
166 #それそれのパラメータ更新用のoptimizerを定義
167 optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5)
168 optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5)
169
170 Loss_D_list, Loss_G_list = [], []
171
172 fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device) # save_fake_image用ノイズ(固定)
173 for epoch in range(n_epoch):
174
175 for itr, data in enumerate(dataloader):
176
177
178 real_image = data[0].to(device) # 本物画像
179
180 sample_size = real_image.size(0) # 画像枚数
181 noise = torch.randn(sample_size, nz, 1, 1, device=device) # 入力ベクトル生成(正規分布ノイズ)
182 real_target = torch.full((sample_size,), random.uniform(1, 1), device=device) # 本物ラベル
183 fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル
184
185
186 #Generator→Discriminatorの順にパラメータ更新
187
188 #--------- Update Generator ----------
189 netG.zero_grad()
190 fake_image = netG(noise) # Generatorから得られた偽画像
191 output = netD(fake_image) # 更新した Discriminatorで、偽物画像を判定
192 errG = criterion(output,real_target) # 偽物画像の判定結果と本物ラベルとの二乗誤差
193 errG.backward(retain_graph = True) # 誤差逆伝播
194 D_G_z2 = output.mean().item() # outputの平均 D_G_z2 を計算(後でログ出力に使用)
195
196 optimizerG.step() # Generatorのパラメータ更新
197
198 #--------- Update Discriminaator ----------
199 netD.zero_grad() # 勾配の初期化
200 fake_image = netG(noise) # Generatorから得られた偽画像
201 output = netD(real_image) # Discriminatorが行った、本物画像の判定結果
202 errD_real = criterion(output,real_target) # 本物画像の判定結果と本物ラベルとの二乗誤差
203 D_x = output.mean().item() # outputの平均 D_x を計算(後でログ出力に使用)
204
205 output = netD(fake_image.detach()) # Discriminatorが行った、偽物画像の判定結果
206
207 errD_fake = criterion(output,fake_target) # 偽物画像の判定結果と偽物画像との二乗誤差
208 D_G_z1 = output.mean().item() # outputの平均 D_G_z1 を計算(後でログ出力に使用)
209
210 errD = errD_real + errD_fake # Discriminator 全体の損失
211 errD.backward(retain_graph = True) # 誤差逆伝播
212 optimizerD.step() # Discriminatoeのパラメーター更新
213
214
215
216
217
218
219 #定期的に損失を表示
220 if itr % 5 == 0:
221 print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
222 .format(epoch + 1, n_epoch,
223 itr + 1, len(dataloader),
224 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
225
226 Loss_D_list.append(errD.item())
227 Loss_G_list.append(errG.item())
228
229 #定期的に画像を保存
230 if (itr + 1) % 50 == 0:
231
232 fake_image = netG(fixed_noise)
233 vutils.save_image(fake_image.detach(), './GAN/{:03d}random_{:03d}.png'.format(itr,epoch + 1),
234 normalize=True, nrow=8)
235
236
237
238
239 # plot graph
240 plt.figure()
241 plt.plot(range(len(Loss_D_list)), Loss_D_list, color='blue', linestyle='-', label='Loss_D')
242 plt.plot(range(len(Loss_G_list)), Loss_G_list, color='red', linestyle='-', label='Loss_G')
243 plt.legend()
244 plt.xlabel('iter (*100)')
245 plt.ylabel('loss')
246 plt.title('Loss_D and Loss_G')
247 plt.grid()
248 plt.savefig('Loss_graph.png')
249
250if __name__ == '__main__':
251 main()
さいごに
今回は、画像生成モデルについてお話ししました。
画像生成モデルは、「GAN」以外にも様々存在し、「GAN」自体の派生形も多く存在します。
その研究は、盛んに行われており、多くの研究者が血肉を割いて実験しています。
この分野での成長は凄まじく、今までの常識だった事がたった数ヶ月で覆される可能性もあります。
そのため私たちエンジニアは、日々新しい知識を吸収し自分の手で動かし、使えるようにならなければなりません。
もしかしたら、次に技術革新を起こすのは、あなたかもしれません!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...
参考
【GANの元論文】
https://arxiv.org/abs/1406.2661
【LSGAN】
https://arxiv.org/abs/1611.04076
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit