
GANで本物のように精巧な画像生成モデルを作ってみた【Pytorch】
2021.12.20
GANとは?
あなたは下の画像が「機械が生成した実在しない人の顔写真」か「実在する人の顔写真」かを見分けられますか?

近年AI技術が発達し、「機械が考える」レベルが現実的になってきています。
また、CNN(Convolution Neural Network)の登場により、画像分野はさらに目覚ましい発展を遂げ、未知の画像を生成する生成モデルなるものが提案されました。
その中でも、ここ最近では「GAN(Generative Adversarial Networks)」とよばれる技術が一躍注目を浴びています。
実は、先程の画像も、GANが生成した画像であり、現実に存在する本物の人ではありません。
ただ、機械が作ったとは思えないほど精巧さで、判別は難しいです。
さて今回は、そんなGANの仕組みの解説と実践を行ってみたいと思います!
GANのネットワーク構造
GANのネットワーク構造は2つに分かれます。
- 実際に画像を生成するGenerator
- 画像が与えられたとき、それが機械から生成された画像かどうかを判定する識別器であるDiscriminator
(※ニューラルネットは識別問題を解くのが得意です)
Generator:画像生成器
「Generator」のネットワーク構造は、以下の画像のようになっています。
ノイズから顔画像を生成します。
Discriminator:本物or偽物かの識別機
「Discriminator」のネットワーク構造は、このようになっています。
画像を入力として、その画像が「本物」か「偽物」かどうかを判定します。
ノイズについて
ノイズは、いわゆる正規分布からサンプリングされます。
この分布の次元は、一般に画像空間よりも低次元です。
(※画像空間は、縦×横のピクセル数だけ次元が存在します)
2次元で考えた場合
例えば、2次元で考えたとしましょう。
そうしたとき上の図のように、「二次元上の点」が「ひとつの画像」を表すように画像空間を埋め込みます。
それぞれの青丸の点は、正規分布からサンプリングされた点です。
このように、「Generator」の入力部のノイズは二次元上の点を表しています。(今回の場合)
そして出力は、その二次元上の点に埋め込まれた画像空間上の画像を生成することになります。
本来のノイズは、もっと高次元です。
さすがに、顔写真を二次元空間に落とし込むのは無理があります。
GANの損失関数
それでは、まずは学習の話です。
「Discriminator」の学習では、本物を本物、偽物を偽物と見分ける能力が必要です。
「Generator」から生成された画像に対しては、偽物ラベルの「0」を出力。
データセットの画像に対しては、本物ラベルの「1」を出力するようにします。
こうすることで「Discriminator」は、本物画像と偽物画像それぞれの特徴を学習し、本物か偽物かを見分ける能力を伸ばしていきます。
特徴を学習するための損失関数
そのため。損失関数は以下のようになります。
$$\frac{1}{2}(D(x)-1)^2+\frac{1}{2}(D(G(z)))^2$$
\(x\) | 本物の顔画像 |
\(z\) | ノイズ |
\(D\) | \(Discriminator\) |
\(G\) | \(Generator\) |
そして「Generator」の学習では、偽物画像を本物だと騙したいため、「Generator」から出力された画像を入力として「Discriminator」に入れた時の出力が「1」に近づくようにします。
偽物を本物と思わせるための損失関数
そのため、損失関数は以下のようになります。
$$\frac{1}{2}(D(G(z))-1)^2$$
\(x\) | 本物の顔画像 |
\(z\) | ノイズ |
\(D\) | \(Discriminator\) |
\(G\) | \(Generator\) |
このようにして「Generator」は、本物画像と偽物画像の特徴を学習した「Discriminator」を、さらに騙すように偽物画像をより本物画像へと近づけていきます。
これが「Adversarial(敵対的)」という由来です。
GANの実験結果
では、実際にGANの学習を行ってみます!
そして、epoch 1~14までの結果を見ていきましょう!
使ったデータセットは、以下のリンクのzipフォルダを解凍した画像群です。
かなり枚数が多いので、その一部の20万枚だけ使いました。
【Celeb-aデータセット】
https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
発生させるノイズは固定で、どのように顔が作られていくかを見ることができます。
epoch = 1

epoch = 2(ぼんやりと顔の形が)

epoch = 3(水彩画みたいな顔が出来始める)

epoch = 4(サンダーバードみたいな顔が出来はじめる)

epoch = 5(あまり変化なし、大体みんな同じ顔)

epoch = 6(輪郭が綺麗に)

epoch = 7(ちらほらいい結果が)

epoch = 8(割愛)

epoch = 9

epoch = 10

epoch = 11

epoch = 12

epoch = 13(斜め顔は苦手っぽい?)

epoch = 14

結果に対する感想
顔画像の生成が出来ましたが、なにより生成データの多様性があるのが素晴らしいと思いました。
笑っている人もいれば、口を閉じている人もいるし、顔の向きも特徴付けられています。
それに肌の色なんかも、よく再現できているのではないでしょうか。
GAN は学習自体が難しいのですが、たった数epochで成果が目で見えるので、総合的にみると楽になるな感じました。
そして、学習は「データ数が命」ということに気付かされました。
実験に用いたソースコード(Pytorch)
では最後に、実験に用いたソースコードを紹介したいと思います。
まず、Pytorch が入っていない方はコチラ!
【Pytorch公式サイト】
https://pytorch.org/
こちらから、自分の環境にあったインストールコードを terminal に入力します。
ソースコード
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 | # -*- coding: utf-8 -*- import os import random import numpy as np import torch.nn as nn import torch.optim as optim import torch.utils.data import torchvision.datasets as dset import torchvision.transforms as transforms import torchvision.utils as vutils import matplotlib.pyplot as plt import copy from collections import OrderedDict # Initial_setting workers = 1 batch_size=64 nz = 100 nch_g = 64 nch_d = 64 n_epoch = 10000 lr = 0.001 beta1 = 0.5 outf = './result_lsgan' display_interval = 100 save_fake_image_interval = 1500 plt.rcParams['figure.figsize'] = 10, 6 try: os.makedirs(outf, exist_ok=True) except OSError as error: print(error) pass random.seed(0) np.random.seed(0) torch.manual_seed(0) def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) m.bias.data.fill_(0) elif classname.find('Linear') != -1: m.weight.data.normal_(0.0, 0.02) m.bias.data.fill_(0) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) class Generator(nn.Module): def __init__(self, nz=100, nch_g=64, nch=3): super(Generator, self).__init__() self.layers = nn.ModuleDict({ 'layer0': nn.Sequential( nn.ConvTranspose2d(nz, nch_g * 8, 4, 1, 0), nn.BatchNorm2d(nch_g * 8), nn.ReLU() ), # (100, 1, 1) -> (512, 4, 4) 'layer1': nn.Sequential( nn.ConvTranspose2d(nch_g * 8, nch_g * 4, 4, 2, 1), nn.BatchNorm2d(nch_g * 4), nn.ReLU() ), # (512, 4, 4) -> (256, 8, 8) 'layer2': nn.Sequential( nn.ConvTranspose2d(nch_g * 4, nch_g * 2, 4, 2, 1), nn.BatchNorm2d(nch_g * 2), nn.ReLU() ), # (256, 8, 8) -> (128, 16, 16) 'layer3': nn.Sequential( nn.ConvTranspose2d(nch_g * 2, nch_g, 4, 2, 1), nn.BatchNorm2d(nch_g), nn.ReLU() ), # (128, 16, 16) -> (64, 32, 32) 'layer4': nn.Sequential( nn.ConvTranspose2d(nch_g, nch, 4, 2, 1), nn.Tanh() ) # (64, 32, 32) -> (3, 64, 64) }) def forward(self, z): for layer in self.layers.values(): z = layer(z) return z class Discriminator(nn.Module): def __init__(self, nch=3, nch_d=64): super(Discriminator, self).__init__() self.layers = nn.ModuleDict({ 'layer0': nn.Sequential( nn.Conv2d(nch, nch_d, 4, 2, 1), nn.BatchNorm2d(nch_d), nn.LeakyReLU(negative_slope=0.2) ), # (3, 64, 64) -> (64, 32, 32) 'layer1': nn.Sequential( nn.Conv2d(nch_d, nch_d * 2, 4, 2, 1), nn.BatchNorm2d(nch_d*2), nn.LeakyReLU(negative_slope=0.2) ), # (64, 32, 32) -> (128, 16, 16) 'layer2': nn.Sequential( nn.Conv2d(nch_d * 2, nch_d * 4, 4, 2, 1), nn.BatchNorm2d(nch_d*4), nn.LeakyReLU(negative_slope=0.2) ), # (128, 16, 16) -> (256, 8, 8) 'layer3': nn.Sequential( nn.Conv2d(nch_d * 4, nch_d * 8, 4, 2, 1), nn.BatchNorm2d(nch_d*8), nn.LeakyReLU(negative_slope=0.2) ), # (256, 8, 8) -> (512, 4, 4) 'layer4':nn.Sequential( nn.Conv2d(nch_d * 8, 1, 4, 1, 0) # (512, 4, 4) -> (1, 1, 1) #勾配消失を防ぐためにSigmoidは使わない ) }) def forward(self, x): for layer in self.layers.values(): x = layer(x) x = x.squeeze() return x def main(): #パスには入っている画像データセットの一個上の改装を #例:もし your_home/Face_Datasets/Japanese/000.jpgのような階層になっていれば #root = your_home/Face_Datasetsとする dataset = dset.ImageFolder(root='C:/Users/User/Downloads/img_align_celeba/', transform=transforms.Compose([ transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=int(workers)) #GPUがあるならGPUデバイスを作動 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #Generatorを定義 netG = Generator(nz=nz, nch_g=nch_g).to(device) #ネットワークパラメ-タをランダムに初期化 netG.apply(weights_init) #Discriminatorを定義 netD = Discriminator(nch_d=nch_d).to(device) #ネットワークパラメ-タをランダムに初期化 netD.apply(weights_init) #損失関数を二乗誤差に設定 criterion = nn.MSELoss() #それそれのパラメータ更新用のoptimizerを定義 optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5) optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5) Loss_D_list, Loss_G_list = [], [] fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device) # save_fake_image用ノイズ(固定) for epoch in range(n_epoch): for itr, data in enumerate(dataloader): real_image = data[0].to(device) # 本物画像 sample_size = real_image.size(0) # 画像枚数 noise = torch.randn(sample_size, nz, 1, 1, device=device) # 入力ベクトル生成(正規分布ノイズ) real_target = torch.full((sample_size,), random.uniform(1, 1), device=device) # 本物ラベル fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル #Generator→Discriminatorの順にパラメータ更新 #--------- Update Generator ---------- netG.zero_grad() fake_image = netG(noise) # Generatorから得られた偽画像 output = netD(fake_image) # 更新した Discriminatorで、偽物画像を判定 errG = criterion(output,real_target) # 偽物画像の判定結果と本物ラベルとの二乗誤差 errG.backward(retain_graph = True) # 誤差逆伝播 D_G_z2 = output.mean().item() # outputの平均 D_G_z2 を計算(後でログ出力に使用) optimizerG.step() # Generatorのパラメータ更新 #--------- Update Discriminaator ---------- netD.zero_grad() # 勾配の初期化 fake_image = netG(noise) # Generatorから得られた偽画像 output = netD(real_image) # Discriminatorが行った、本物画像の判定結果 errD_real = criterion(output,real_target) # 本物画像の判定結果と本物ラベルとの二乗誤差 D_x = output.mean().item() # outputの平均 D_x を計算(後でログ出力に使用) output = netD(fake_image.detach()) # Discriminatorが行った、偽物画像の判定結果 errD_fake = criterion(output,fake_target) # 偽物画像の判定結果と偽物画像との二乗誤差 D_G_z1 = output.mean().item() # outputの平均 D_G_z1 を計算(後でログ出力に使用) errD = errD_real + errD_fake # Discriminator 全体の損失 errD.backward(retain_graph = True) # 誤差逆伝播 optimizerD.step() # Discriminatoeのパラメーター更新 #定期的に損失を表示 if itr % 5 == 0: print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}' .format(epoch + 1, n_epoch, itr + 1, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) Loss_D_list.append(errD.item()) Loss_G_list.append(errG.item()) #定期的に画像を保存 if (itr + 1) % 50 == 0: fake_image = netG(fixed_noise) vutils.save_image(fake_image.detach(), './GAN/{:03d}random_{:03d}.png'.format(itr,epoch + 1), normalize=True, nrow=8) # plot graph plt.figure() plt.plot(range(len(Loss_D_list)), Loss_D_list, color='blue', linestyle='-', label='Loss_D') plt.plot(range(len(Loss_G_list)), Loss_G_list, color='red', linestyle='-', label='Loss_G') plt.legend() plt.xlabel('iter (*100)') plt.ylabel('loss') plt.title('Loss_D and Loss_G') plt.grid() plt.savefig('Loss_graph.png') if __name__ == '__main__': main() |
さいごに
今回は、画像生成モデルについてお話ししました。
画像生成モデルは、「GAN」以外にも様々存在し、「GAN」自体の派生形も多く存在します。
その研究は、盛んに行われており、多くの研究者が血肉を割いて実験しています。
この分野での成長は凄まじく、今までの常識だった事がたった数ヶ月で覆される可能性もあります。
そのため私たちエンジニアは、日々新しい知識を吸収し自分の手で動かし、使えるようにならなければなりません。
もしかしたら、次に技術革新を起こすのは、あなたかもしれません!
こちらの記事もオススメ!
参考
【GANの元論文】
https://arxiv.org/abs/1406.2661
【LSGAN】
https://arxiv.org/abs/1611.04076
書いた人はこんな人

- 「好きを仕事にするエンジニア集団」の(株)ライトコードです!
ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。
システム開発依頼・お見積もり大歓迎!
また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。
以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit
ライトコードの日常12月 1, 2023ライトコードクエスト〜東京オフィス歴史編〜
ITエンタメ10月 13, 2023Netflixの成功はレコメンドエンジン?
ライトコードの日常8月 30, 2023退職者の最終出社日に密着してみた!
ITエンタメ8月 3, 2023世界初の量産型ポータブルコンピュータを開発したのに倒産!?アダム・オズボーン