• トップ
  • ブログ一覧
  • GANで本物のように精巧な画像生成モデルを作ってみた【Pytorch】
  • GANで本物のように精巧な画像生成モデルを作ってみた【Pytorch】

    広告メディア事業部広告メディア事業部
    2020.03.31

    IT技術

    GANとは?

    あなたは下の画像が「機械が生成した実在しない人の顔写真」か「実在する人の顔写真」かを見分けられますか?

    近年AI技術が発達し、「機械が考える」レベルが現実的になってきています。

    また、CNN(Convolution Neural Network)の登場により、画像分野はさらに目覚ましい発展を遂げ、未知の画像を生成する生成モデルなるものが提案されました。

    その中でも、ここ最近では「GAN(Generative Adversarial Networks)」とよばれる技術が一躍注目を浴びています。

    実は、先程の画像も、GANが生成した画像であり、現実に存在する本物の人ではありません

    ただ、機械が作ったとは思えないほど精巧さで、判別は難しいです。

    さて今回は、そんなGANの仕組みの解説と実践を行ってみたいと思います!

    GANのネットワーク構造

    GANのネットワーク構造は2つに分かれます。

    1. 実際に画像を生成するGenerator
    2. 画像が与えられたとき、それが機械から生成された画像かどうかを判定する識別器であるDiscriminator

    (※ニューラルネットは識別問題を解くのが得意です)

    Generator:画像生成器

    「Generator」のネットワーク構造は、以下の画像のようになっています。

    ノイズから顔画像を生成します。

    Discriminator:本物or偽物かの識別機

    「Discriminator」のネットワーク構造は、このようになっています。

    画像を入力として、その画像が「本物」か「偽物」かどうかを判定します。

    ノイズについて

    ノイズは、いわゆる正規分布からサンプリングされます。

    この分布の次元は、一般に画像空間よりも低次元です。

    (※画像空間は、縦×横のピクセル数だけ次元が存在します)

    2次元で考えた場合

    例えば、2次元で考えたとしましょう。

    そうしたとき上の図のように、「二次元上の点」が「ひとつの画像」を表すように画像空間を埋め込みます

    それぞれの青丸の点は、正規分布からサンプリングされた点です。

    このように、「Generator」の入力部のノイズは二次元上の点を表しています。(今回の場合)

    そして出力は、その二次元上の点に埋め込まれた画像空間上の画像を生成することになります。

    本来のノイズは、もっと高次元です。

    さすがに、顔写真を二次元空間に落とし込むのは無理があります。

    GANの損失関数

    それでは、まずは学習の話です。

    「Discriminator」の学習では、本物を本物、偽物を偽物と見分ける能力が必要です

    「Generator」から生成された画像に対しては、偽物ラベルの「0」を出力

    データセットの画像に対しては、本物ラベルの「1」を出力するようにします。

    こうすることで「Discriminator」は、本物画像と偽物画像それぞれの特徴を学習し、本物か偽物かを見分ける能力を伸ばしていきます。

    特徴を学習するための損失関数

    そのため。損失関数は以下のようになります。

    12(D(x)1)2+12(D(G(z)))2\frac{1}{2}(D(x)-1)^2+\frac{1}{2}(D(G(z)))^2

    xx本物の顔画像
    zzノイズ
    DDDiscriminatorDiscriminator
    GGGeneratorGenerator

    そして「Generator」の学習では、偽物画像を本物だと騙したいため、「Generator」から出力された画像を入力として「Discriminator」に入れた時の出力が「1」に近づくようにします。

    偽物を本物と思わせるための損失関数

    そのため、損失関数は以下のようになります。

    12(D(G(z))1)2\frac{1}{2}(D(G(z))-1)^2

    xx本物の顔画像
    zzノイズ
    DDDiscriminatorDiscriminator
    GGGeneratorGenerator

    このようにして「Generator」は、本物画像と偽物画像の特徴を学習した「Discriminator」を、さらに騙すように偽物画像をより本物画像へと近づけていきます

    これが「Adversarial(敵対的)」という由来です。

    GANの実験結果

    では、実際にGANの学習を行ってみます!

    そして、epoch 1~14までの結果を見ていきましょう!

    使ったデータセットは、以下のリンクのzipフォルダを解凍した画像群です。

    かなり枚数が多いので、その一部の20万枚だけ使いました。

    【Celeb-aデータセット】
    https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg

    発生させるノイズは固定で、どのように顔が作られていくかを見ることができます。

    epoch = 1

    epoch = 2(ぼんやりと顔の形が)

    epoch = 3(水彩画みたいな顔が出来始める)

    epoch = 4(サンダーバードみたいな顔が出来はじめる)

    epoch = 5(あまり変化なし、大体みんな同じ顔)

    epoch = 6(輪郭が綺麗に)

    epoch = 7(ちらほらいい結果が)

    epoch = 8(割愛)

    epoch = 9

    epoch = 10

    epoch = 11

    epoch = 12

    epoch = 13(斜め顔は苦手っぽい?)

    epoch = 14

    結果に対する感想

    顔画像の生成が出来ましたが、なにより生成データの多様性があるのが素晴らしいと思いました。

    笑っている人もいれば、口を閉じている人もいるし、顔の向きも特徴付けられています。

    それに肌の色なんかも、よく再現できているのではないでしょうか。

    GAN は学習自体が難しいのですが、たった数epochで成果が目で見えるので、総合的にみると楽になるな感じました。

    そして、学習は「データ数が命」ということに気付かされました。

    実験に用いたソースコード(Pytorch)

    では最後に、実験に用いたソースコードを紹介したいと思います。

    まず、Pytorch が入っていない方はコチラ!

    【Pytorch公式サイト】
    https://pytorch.org/

    こちらから、自分の環境にあったインストールコードを terminal に入力します。

    ソースコード

    1# -*- coding: utf-8 -*-
    2
    3
    4import os
    5import random
    6import numpy as np
    7import torch.nn as nn
    8import torch.optim as optim
    9import torch.utils.data
    10import torchvision.datasets as dset
    11import torchvision.transforms as transforms
    12import torchvision.utils as vutils
    13import matplotlib.pyplot as plt
    14import copy
    15from collections import OrderedDict
    16# Initial_setting
    17workers = 1
    18batch_size=64
    19nz = 100
    20nch_g = 64
    21nch_d = 64
    22n_epoch = 10000
    23lr = 0.001
    24beta1 = 0.5
    25outf = './result_lsgan'
    26display_interval = 100
    27save_fake_image_interval = 1500
    28plt.rcParams['figure.figsize'] = 10, 6
    29
    30
    31try:
    32        os.makedirs(outf, exist_ok=True)
    33except OSError as error:
    34        print(error)
    35        pass
    36 
    37random.seed(0)
    38np.random.seed(0)
    39torch.manual_seed(0)
    40 
    41 
    42def weights_init(m):
    43        classname = m.__class__.__name__
    44        if classname.find('Conv') != -1:                
    45            m.weight.data.normal_(0.0, 0.02)
    46            m.bias.data.fill_(0)
    47        elif classname.find('Linear') != -1:            
    48            m.weight.data.normal_(0.0, 0.02)
    49            m.bias.data.fill_(0)
    50        elif classname.find('BatchNorm') != -1:        
    51            m.weight.data.normal_(1.0, 0.02)
    52            m.bias.data.fill_(0)
    53 
    54 
    55class Generator(nn.Module):
    56        def __init__(self, nz=100, nch_g=64, nch=3):
    57            super(Generator, self).__init__()
    58            self.layers = nn.ModuleDict({
    59                'layer0': nn.Sequential(
    60                    nn.ConvTranspose2d(nz, nch_g * 8, 4, 1, 0),         
    61                    nn.BatchNorm2d(nch_g * 8),                          
    62                    nn.ReLU()                                        
    63                ),  # (100, 1, 1) -> (512, 4, 4)
    64                'layer1': nn.Sequential(
    65                    nn.ConvTranspose2d(nch_g * 8, nch_g * 4, 4, 2, 1),
    66                    nn.BatchNorm2d(nch_g * 4),
    67                    nn.ReLU()         
    68                ),  # (512, 4, 4) -> (256, 8, 8)
    69                'layer2': nn.Sequential(
    70                    nn.ConvTranspose2d(nch_g * 4, nch_g * 2, 4, 2, 1),
    71                    nn.BatchNorm2d(nch_g * 2),
    72                    nn.ReLU()  
    73                ),  # (256, 8, 8) -> (128, 16, 16)
    74 
    75                'layer3': nn.Sequential(
    76                    nn.ConvTranspose2d(nch_g * 2, nch_g, 4, 2, 1),
    77                    nn.BatchNorm2d(nch_g),
    78                    nn.ReLU()        
    79                ),  # (128, 16, 16) -> (64, 32, 32)
    80                'layer4': nn.Sequential(
    81                    nn.ConvTranspose2d(nch_g, nch, 4, 2, 1),
    82                    nn.Tanh()
    83                )   # (64, 32, 32) -> (3, 64, 64)
    84            })
    85 
    86        def forward(self, z):
    87            for layer in self.layers.values(): 
    88                z = layer(z)
    89            return z
    90 
    91 
    92class Discriminator(nn.Module):
    93        def __init__(self, nch=3, nch_d=64):
    94            super(Discriminator, self).__init__()
    95            self.layers = nn.ModuleDict({
    96                'layer0': nn.Sequential(
    97                    nn.Conv2d(nch, nch_d, 4, 2, 1), 
    98                    nn.BatchNorm2d(nch_d),
    99                    nn.LeakyReLU(negative_slope=0.2)        
    100                ),  # (3, 64, 64) -> (64, 32, 32)
    101                'layer1': nn.Sequential(
    102                    nn.Conv2d(nch_d, nch_d * 2, 4, 2, 1),
    103                    nn.BatchNorm2d(nch_d*2),
    104                    nn.LeakyReLU(negative_slope=0.2)
    105                ),  # (64, 32, 32) -> (128, 16, 16)
    106                'layer2': nn.Sequential(
    107                    nn.Conv2d(nch_d * 2, nch_d * 4, 4, 2, 1),
    108                    nn.BatchNorm2d(nch_d*4),
    109                    nn.LeakyReLU(negative_slope=0.2)
    110                ),  # (128, 16, 16) -> (256, 8, 8)
    111                'layer3': nn.Sequential(
    112                    nn.Conv2d(nch_d * 4, nch_d * 8, 4, 2, 1),
    113                    nn.BatchNorm2d(nch_d*8),
    114                    nn.LeakyReLU(negative_slope=0.2)
    115                ),  # (256, 8, 8) -> (512, 4, 4)
    116                'layer4':nn.Sequential( nn.Conv2d(nch_d * 8, 1, 4, 1, 0)
    117                # (512, 4, 4) -> (1, 1, 1)
    118                #勾配消失を防ぐためにSigmoidは使わない
    119                )
    120            })
    121 
    122        def forward(self, x):
    123            for layer in self.layers.values(): 
    124                x = layer(x)
    125                x = x.squeeze()
    126               
    127            return x
    128 
    129        
    130        
    131
    132 
    133def main():
    134        #パスには入っている画像データセットの一個上の改装を
    135        #例:もし  your_home/Face_Datasets/Japanese/000.jpgのような階層になっていれば
    136        #root = your_home/Face_Datasetsとする
    137        dataset = dset.ImageFolder(root='C:/Users/User/Downloads/img_align_celeba/',
    138                                  transform=transforms.Compose([
    139                                  transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)),
    140                                  transforms.RandomHorizontalFlip(),
    141                                  transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    142                                  transforms.ToTensor(),
    143                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    144                              ])) 
    145  
    146        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
    147                                             shuffle=True, num_workers=int(workers))
    148 
    149        #GPUがあるならGPUデバイスを作動
    150        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    151        
    152        #Generatorを定義
    153        netG = Generator(nz=nz, nch_g=nch_g).to(device)
    154        #ネットワークパラメ-タをランダムに初期化
    155        netG.apply(weights_init)        
    156  
    157        #Discriminatorを定義
    158        netD = Discriminator(nch_d=nch_d).to(device)
    159        
    160        #ネットワークパラメ-タをランダムに初期化
    161        netD.apply(weights_init)
    162        
    163        #損失関数を二乗誤差に設定
    164        criterion = nn.MSELoss()  
    165 
    166        #それそれのパラメータ更新用のoptimizerを定義
    167        optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5) 
    168        optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=1e-5) 
    169 
    170        Loss_D_list, Loss_G_list = [], []
    171        
    172        fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)  # save_fake_image用ノイズ(固定)
    173        for epoch in range(n_epoch):
    174            
    175            for itr, data in enumerate(dataloader):
    176                
    177                
    178                real_image = data[0].to(device)   # 本物画像
    179                
    180                sample_size = real_image.size(0)  # 画像枚数
    181                noise = torch.randn(sample_size, nz, 1, 1, device=device)   # 入力ベクトル生成(正規分布ノイズ)           
    182                real_target = torch.full((sample_size,), random.uniform(1, 1), device=device)   # 本物ラベル
    183                fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device)   # 偽物ラベル
    184                
    185               
    186                #Generator→Discriminatorの順にパラメータ更新
    187                
    188                #---------  Update Generator   ----------
    189                netG.zero_grad()             
    190                fake_image = netG(noise) # Generatorから得られた偽画像
    191                output = netD(fake_image)   # 更新した Discriminatorで、偽物画像を判定
    192                errG = criterion(output,real_target)   # 偽物画像の判定結果と本物ラベルとの二乗誤差
    193                errG.backward(retain_graph = True)         # 誤差逆伝播
    194                D_G_z2 = output.mean().item()  # outputの平均 D_G_z2 を計算(後でログ出力に使用)
    195 
    196                optimizerG.step()  # Generatorのパラメータ更新
    197                
    198                #---------  Update Discriminaator   ----------
    199                netD.zero_grad() # 勾配の初期化  
    200                fake_image = netG(noise) # Generatorから得られた偽画像
    201                output = netD(real_image)   # Discriminatorが行った、本物画像の判定結果
    202                errD_real = criterion(output,real_target)  # 本物画像の判定結果と本物ラベルとの二乗誤差
    203                D_x = output.mean().item()  # outputの平均 D_x を計算(後でログ出力に使用)
    204             
    205                output = netD(fake_image.detach())  # Discriminatorが行った、偽物画像の判定結果
    206                
    207                errD_fake = criterion(output,fake_target)  # 偽物画像の判定結果と偽物画像との二乗誤差
    208                D_G_z1 = output.mean().item()  # outputの平均 D_G_z1 を計算(後でログ出力に使用)
    209 
    210                errD =  errD_real + errD_fake        # Discriminator 全体の損失
    211                errD.backward(retain_graph = True)        # 誤差逆伝播
    212                optimizerD.step()   # Discriminatoeのパラメーター更新
    213                
    214                   
    215                        
    216                
    217                
    218 
    219                #定期的に損失を表示
    220                if itr % 5 == 0:
    221                     print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
    222                          .format(epoch + 1, n_epoch,
    223                                  itr + 1, len(dataloader),
    224                                  errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
    225                     
    226                     Loss_D_list.append(errD.item())
    227                     Loss_G_list.append(errG.item())
    228                                  
    229                #定期的に画像を保存
    230                if (itr + 1) % 50 == 0:
    231                    
    232                    fake_image = netG(fixed_noise) 
    233                    vutils.save_image(fake_image.detach(), './GAN/{:03d}random_{:03d}.png'.format(itr,epoch + 1),
    234                                      normalize=True, nrow=8)
    235         
    236            
    237            
    238 
    239        # plot graph
    240        plt.figure()        
    241        plt.plot(range(len(Loss_D_list)), Loss_D_list, color='blue', linestyle='-', label='Loss_D')
    242        plt.plot(range(len(Loss_G_list)), Loss_G_list, color='red', linestyle='-', label='Loss_G')
    243        plt.legend()
    244        plt.xlabel('iter (*100)')
    245        plt.ylabel('loss')
    246        plt.title('Loss_D and Loss_G')
    247        plt.grid()
    248        plt.savefig('Loss_graph.png')        
    249 
    250if __name__ == '__main__':
    251        main()

    さいごに

    今回は、画像生成モデルについてお話ししました。

    画像生成モデルは、「GAN」以外にも様々存在し、「GAN」自体の派生形も多く存在します。

    その研究は、盛んに行われており、多くの研究者が血肉を割いて実験しています。

    この分野での成長は凄まじく、今までの常識だった事がたった数ヶ月で覆される可能性もあります。

    そのため私たちエンジニアは、日々新しい知識を吸収し自分の手で動かし、使えるようにならなければなりません

    もしかしたら、次に技術革新を起こすのは、あなたかもしれません!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    featureImg2020.07.30Python 特集実装編※最新記事順Responder + Firestore でモダンかつサーバーレスなブログシステムを作ってみた!P...

    参考

    【GANの元論文】
    https://arxiv.org/abs/1406.2661

    【LSGAN】
    https://arxiv.org/abs/1611.04076

    広告メディア事業部

    広告メディア事業部

    おすすめ記事