• トップ
  • ブログ一覧
  • pix2pixで白黒画像をカラー変換する
  • pix2pixで白黒画像をカラー変換する

    広告メディア事業部広告メディア事業部
    2020.06.30

    IT技術

    pix2pix とは?

    pix2pix は、画像から画像への変換に用いられます。

    例えば、「白黒画像からカラー画像への変換」であったり、「ラベル画像から元画像への変換」であったりなどです。

    pix2pix のメリット

    では、画像 to 画像の変換ができるようになると、どんなメリットがあるのでしょうか?

    例えば、デジタルイラストは線画を書いた後、ペイントツールなどで色を塗ります。

    その際、細かい調整が効かなかったりするため、色塗りが大変であることは言うまでもありません。

    しかし、pix2pix では画像のセットさえあれば何でも変換できるため、「線画から色のついた画像を生成する」ことができるようになります。

    これが実現すれば、デジタルイラストレーターの時間効率が飛躍的に上がるようになります。

    pix2pix の学習法

    Generator の学習

    通常の GAN に比べ、pix2pix では直接、生成画像と正解画像を近づける項を設けます。

    画像を直接近づける際の誤差関数は、「L1Loss」をとります。

    pix2pix の論文でも、「L1Loss」を取っています。

    【Image-to-Image Translation with Conditional Adversarial Networks】
    https://arxiv.org/pdf/1611.07004.pdf

    Discriminator の学習

    一方、「Discriminator」は、通常の GAN と変わりません。

    正解画像は「1」に近づけ、不正解画像を「0」に近づけさせ、「偽物」か「本物」かを、Discriminator が見極められるようにします。

    Generatorの構造

    Generator は、「U-net 構造」をとります。

    U-net 構造とは?

    U-net は、以下のような構造をしています。

    「画像 to 画像」の変換をする際、画像がボケることがあります。

    オートエンコーダのようなボトルネックがある場合、その分の次元が圧縮されてしまうため、復元が難しくなるからです。

    U-net もオートエンコーダ同様、「画像to画像」の変換です。

    画像の大局的な部分と局所的な部分の両方を加味することで、高解像度な画像を生成することができます。

    そのため、Generator には 「U-net」を採用します。

    実装時の学習

    まずは、Generator を学習します。

    損失関数には、「L1Loss」を加えています。

    1real_image = data[0].to(device)   # 本物画像
    2real_target = torch.full((sample_size,), random.uniform(1, 1), device=device)   # 本物ラベル
    3fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル
    4
    5#______カラー画像の白黒画像化______#
    6_img = torch.Tensor(real_image.shape[0],1,real_image.shape[2],real_image.shape[3]) #
    7_img[:,0,:,:] = (real_image[:,0,:,:]+real_image[:,1,:,:]+real_image[:,2,:,:])/3 
    8gray  =_img.to(device) #カラー画像を白黒画像にしたもの
    9
    10criterion = nn.BCELoss()
    11criterion2 = nn.MSELoss()
    12
    13fake_color = gray2color(gray) #生成画像
    14
    15output = D_color(fake_color) #生成画像に対するDiscriminatorの結果
    16
    17adversarial_color_loss_fake = criterion(output,real_target) #Discriminatorの出力結果と正解ラベルとのBCELoss
    18
    19l1_loss = criterion2(real_image,fake_color) #生成結果と正解画像のL1Loss
    20
    21loss_g = adversarial_color_loss_fake +l1_loss*l1Loss_late #二つの損失をバランスを考えて加算
    22
    23loss_g.backward(retain_graph = True) # 誤差逆伝播
    24
    25optimizergray2color.step()  # Generatorのパラメータ更新

    Discriminator の学習

    次に、Discriminator の学習を行います。

    1fake_color = gray2color(gray) #Generatorの生成画像
    2            
    3output = D_color(fake_color) #生成画像に対するDiscriminatorの出力
    4
    5adversarial_color_loss_fake = criterion(output,fake_target) #Discriminatorの出力結果と偽物ラベルとのBCELoss
    6            
    7            
    8output = D_color(real_image) #正解画像に対するDiscriminatorの出力
    9adversarial_color_loss_real = criterion(output,real_target)#2  #Discriminatorの出力結果と本物ラベルとのBCELoss
    10   
    11            
    12loss_d = adversarial_color_loss_fake+adversarial_color_loss_real #単純に加算
    13loss_d.backward(retain_graph = True) # 誤差逆伝播
    14optimizerD_color.step()  # Discriminatorのパラメータ更新

    学習時の工夫

    学習が全く上手く行きませんでした。

    原因の一つとして挙げられるのが、「バッチサイズ」です。

    今までバッチサイズは大きければ良いと考えていましたが、どうやらバッチサイズが大きいほど過学習してしまうみたいです。

    【バッチサイズは小さい方が良い】
    http://marugari2.hatenablog.jp/entry/2017/12/04/174535

    バッチサイズが大きいと、平均的に誤差を下げます。

    そのため、二乗誤差を使った時と同様に、データセットに最も多い肌色を塗りやすくなり、他の色を使わなくなってしまいます。

    そのため、学習時のバッチサイズは「1」にすることで、ようやく上手く行きました。

    実験結果

    さて今回は、「単純CNN で学習したモデル」と、「pix2pix で学習したモデル」それぞれで白黒画像のカラー化を行いました。

    単純な CNN は、以下のリンクからコードを参照しています。

    featureImg2020.05.21白黒画像を 畳み込みニューラルネットワーク(CNN)を用いてカラー化する畳み込みニューラルネットワーク(CNN)で白黒画像をカラー化しようカメラができた当初、撮影された画像は、今のようなカラ...

    単純な CNN と pix2pix との比較

    結果から見ると、「単純な CNNで 学習した場合」は人の肌色ばかり塗り、他の背景部分などは色がついていません。

    二乗誤差を用いることで、平均的にピクセルレベルの誤差が小さくなれば良いため、データセットの中で最も多い「肌色」を塗っているからです。

    しかし、pix2pix ではピクセル単位での誤差に加えて、画像全体としての評価も行っているため、より色合いが「鮮やか」になっていることがわかります。

    損失に MSE を使った時との比較

    現在、pix2pix の損失関数には「L1Loss」を使っています。

    今度は、MSE つまり「L2Loss」を使ってみると結果がどうなるのか、比較してみたいと思います。

    以下は、その実験結果です。

    この画像から見てわかるように、色は確かに「GAN」により多様性が出てきました。

    L1Loss を使った時より、L2 を使った時の方がボヤッとした色合いになっていることがわかります。

    これも、損失に二乗誤差を用いていることによる弊害ですね。

    GAN を用いても色のボヤッとした感じは残るようです。

    というわけで、pix2pix の損失には「L1Loss」を使った方がいいという結論に至りました!

    さいごに

    「pix2pix」で、よりリアルな色合いのカラー画像の生成が可能になりました。

    L1 と L2 で生成結果の差が出ることも確認できました。

    「画像 to 画像」の変換に対して何らかのタスクを抱えている場合、pix2pix の実装は一考の価値ありです。

    次回は、pix2pix の発展版である「cycleGAN」での画像変換を行いたいと思います。

    次回の記事はこちら

    featureImg2020.08.07cycleGANで男顔⇄女顔への変換を可能してみた!cycleGANとは?「cycleGAN」は pix2pix と違い、ペア画像を必要としない、より画期的な「画像 to...

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    実験に用いたソースコード

    最後に、実験に用いたソースコードを紹介します。

    1import torch
    2import torchvision
    3
    4import torchvision.datasets as dset
    5from torch import nn
    6from torch.autograd import Variable
    7from torch.utils.data import DataLoader
    8from torchvision import transforms
    9from torchvision.utils import save_image
    10from mpl_toolkits.mplot3d import axes3d
    11from torchvision.datasets import MNIST
    12import os
    13import math
    14import pylab
    15import matplotlib.pyplot as plt
    16
    17beta1 = 0.5
    18l1Loss_late  =100 #L1LossとadversarilLossの重要度を決定する係数
    19num_epochs = 1 #エポック数
    20batch_size = 1 #バッチサイズ
    21learning_rate = 1e-3 #学習率
    22train =True#学習を行うかどうかのフラグ
    23pretrained =False#事前に学習したモデルがあるならそれを使う
    24save_img = False #ネットワークによる生成画像を保存するかどうかのフラグ
    25
    26import random
    27def to_img(x):
    28    x = 0.5 * (x + 1)
    29    x = x.clamp(0, 1)
    30    x = x.view(x.size(0), x.shape[1], x.shape[2],x.shape[3])
    31    return x
    32
    33
    34#データセットを調整する関数
    35transform = transforms.Compose(
    36    [transforms.ToTensor(),
    37     transforms.Normalize((0.5, ), (0.5, ))])
    38   
    39#訓練用データセット
    40#ここのパスは自分のGoogleDriveのパスに合うように変えてください
    41dataset = dset.ImageFolder(root='./drive/My Drive/face/tmp3/',
    42                              transform=transforms.Compose([
    43                              transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)),
    44                              transforms.RandomHorizontalFlip(),
    45                              transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    46                              transforms.ToTensor(),
    47                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    48                          ])) 
    49
    50#データセットをdataoaderで読み込み
    51dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    52
    53#pix2pixのGenerator部分
    54class Generator(nn.Module):
    55    def __init__(self,nch,nch_d):
    56        super(Generator, self).__init__()
    57        nch_g = 64
    58        #U-net部分
    59        self.layer1 = self.conv_layer_forward(nch, nch_g , 3, 2, 1)
    60        self.layer2 = self.conv_layer_forward(nch_g , nch_g*2 , 3, 2, 1)
    61        self.layer3 = self.conv_layer_forward(nch_g*2 , nch_g*4 , 3, 2, 1)
    62        self.layer4= self.conv_layer_forward(nch_g*4 , nch_g*8 , 3, 2, 1)
    63        self.layer5= self.conv_layer_forward(nch_g*8 , nch_g*16 , 3, 2, 1)
    64        self.layer6= self.conv_layer_forward_image_size_1(nch_g*16 , nch_g*32 , 4, 1, 1)
    65        self.layer7= self.conv_layer_transpose(nch_g*32 , nch_g*16 , 4, 2, 1,False)
    66        self.layer8 = self.conv_layer_transpose(nch_g*32 , nch_g*8 , 4, 2, 1,False)
    67        self.layer9 = self.conv_layer_transpose(nch_g*16 , nch_g*4 , 4, 2, 1,False)
    68        self.layer10= self.conv_layer_transpose(nch_g*8 , nch_g*2 , 4, 2, 1,False)
    69        self.layer11= self.conv_layer_transpose(nch_g*4 , nch_g , 4, 2, 1,False)
    70        self.layer12 = self.conv_layer_transpose(nch_g*2 , nch_d , 4, 2, 1,True)
    71
    72    def forward(self, z):
    73        z,z1 = self.convolution_forward(self.layer1,z)
    74        z,z2= self.convolution_forward(self.layer2,z)
    75        z,z3 = self.convolution_forward(self.layer3,z)
    76        z,z4 = self.convolution_forward(self.layer4,z)
    77        z,z5 = self.convolution_forward(self.layer5,z)
    78        z = self.convolution(self.layer6,z)
    79        z = self.convolution_deconv(self.layer7,z,z5)
    80        z = self.convolution_deconv(self.layer8,z,z4)
    81        z = self.convolution_deconv(self.layer9,z,z3)
    82        z = self.convolution_deconv(self.layer10,z,z2)
    83        z = self.convolution_deconv(self.layer11,z,z1)
    84        z = self.convolution(self.layer12,z)
    85        return z
    86
    87    def convolution(self,layer_i,z):
    88      for layer in layer_i.values(): 
    89            z = layer(z)
    90      return z
    91    
    92    def conv_layer_forward(self,input,out,kernel_size,stride,padding):
    93        return nn.ModuleDict({
    94              'layer0': nn.Sequential(
    95                  nn.Conv2d(input,out,kernel_size,stride,padding),
    96                  nn.BatchNorm2d(out),
    97                  nn.ReLU()  
    98                  ),  
    99              })
    100        
    101    def conv_layer_forward_image_size_1(self,input,out,kernel_size,stride,padding):
    102        return nn.ModuleDict({
    103              'layer0': nn.Sequential(
    104                  nn.Conv2d(input,out,kernel_size,stride,padding),
    105                  nn.ReLU()  
    106                  ),  
    107              })
    108        
    109    def conv_layer_transpose(self,input,out,kernel_size,stride,padding,is_last):
    110      if is_last == True:
    111        return nn.ModuleDict({
    112              'layer0': nn.Sequential(
    113                  nn.ConvTranspose2d(input , out , kernel_size, stride, padding),
    114                  nn.Tanh()  
    115                  ),
    116              })
    117      else :
    118        return nn.ModuleDict({
    119               'layer0': nn.Sequential(
    120                  nn.ConvTranspose2d(input , out , kernel_size, stride, padding),
    121                  nn.BatchNorm2d(out),
    122                  nn.ReLU()  
    123                  ), 
    124              })
    125        
    126    def convolution_forward(self,layer,z):
    127        z = self.convolution(layer,z)
    128        z_copy = z
    129        return z,z_copy
    130    def convolution_deconv(self,layer,z,z_copy):
    131        z = self.convolution(layer,z)
    132        z = torch.cat([z,z_copy],dim = 1)
    133        return z
    134
    135 
    136class Discriminator(nn.Module):
    137  #Dicriminator部分
    138  def __init__(self, nch=3, nch_d=64):
    139     super(Discriminator, self).__init__()
    140     self.layer1 = self.conv_layer(nch, nch_d, 4, 2, 1,False)
    141     self.layer2 = self.conv_layer(nch_d, nch_d * 2, 4, 2, 1,False)
    142     self.layer3 = self.conv_layer(nch_d * 2, nch_d * 4, 4, 2, 1,False)
    143     self.layer4 = self.conv_layer(nch_d * 4, nch_d * 8, 4, 2, 1,False)
    144     self.layer5 = self.conv_layer(nch_d * 8, 1, 4, 1,0,True)
    145     
    146  def conv_layer(self,input,out,kernel_size,stride,padding,is_last):
    147      if is_last == True:
    148        return nn.ModuleDict({
    149              'layer0': nn.Sequential(
    150                  nn.Conv2d(input , out , kernel_size, stride, padding),
    151                  nn.Tanh()  
    152                  ),
    153              })
    154      else :
    155        return nn.ModuleDict({
    156               'layer0': nn.Sequential(
    157                  nn.Conv2d(input , out , kernel_size, stride, padding),
    158                  nn.BatchNorm2d(out),
    159                  nn.ReLU()  
    160                  ), 
    161              })
    162        
    163  def convolution(self,layer_i,z):
    164      for layer in layer_i.values(): 
    165            z = layer(z)
    166      return z
    167  def forward(self, x):
    168      x = self.convolution(self.layer1,x)
    169      x = self.convolution(self.layer2,x)
    170      x = self.convolution(self.layer3,x)
    171      x = self.convolution(self.layer4,x)
    172      x = self.convolution(self.layer5,x)
    173        
    174      return x
    175def main():
    176    #もしGPUがあるならGPUを使用してないならCPUを使用
    177    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    178   
    179    #ネットワークを呼び出し
    180    gray2color = Generator(1,3).to(device)
    181    
    182
    183    #事前に学習しているモデルがあるならそれを読み込む
    184    #ここのパスは自分のGoogleDriveパスに合うように変えてください
    185    #./drive/My Drive/までは変えなくてできます
    186
    187    if pretrained:
    188        param = torch.load('./drive/My Drive/gray2color.pth')
    189        gray2color.load_state_dict(param)
    190   
    191    D_color = Discriminator(nch=3,nch_d=64).to(device)
    192    if pretrained:
    193        param = torch.load('./drive/My Drive/D_color.pth')
    194        
    195        D_color.load_state_dict(param)
    196
    197    #誤差関数には二乗誤差を使用
    198    criterion = nn.BCELoss()
    199    criterion2 = nn.L1Loss()
    200    #更新式はAdamを適用
    201    
    202    optimizerD_color = torch.optim.Adam(D_color.parameters(), lr=learning_rate, betas=(beta1, 0.999), weight_decay=1e-5) 
    203    optimizergray2color = torch.optim.Adam(gray2color.parameters(), lr=learning_rate, betas=(beta1, 0.999), weight_decay=1e-5) 
    204  
    205    loss_train_list = []
    206    loss_test_list= []
    207    for epoch in range(num_epochs):
    208        print(epoch)
    209        i=0
    210        for data in dataloader:
    211            i=i+1
    212            real_image = data[0].to(device)   # 本物画像
    213            sample_size = real_image.size(0)  # 画像枚数
    214            real_target = torch.full((sample_size,), random.uniform(1, 1), device=device)   # 本物ラベル
    215            fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device)   # 偽物ラベル
    216            
    217            #_imgはカラー画像をモノクロに変換した画像
    218            _img = torch.Tensor(real_image.shape[0],1,real_image.shape[2],real_image.shape[3])
    219            _img[:,0,:,:] = (real_image[:,0,:,:]+real_image[:,1,:,:]+real_image[:,2,:,:])/3
    220            #_img = (real_image[:,:,::2, ::2] +real_image[:,:,1::2, ::2] + real_image[:,:,::2, 1::2] +real_image[:,:,1::2, 1::2])/4
    221            gray =_img.to(device)
    222
    223            #--------Generatorの学習-------
    224            
    225            #勾配情報の初期化
    226            gray2color.zero_grad() 
    227            D_color.zero_grad()
    228
    229            fake_color = gray2color(gray) #生成画像
    230 
    231            output = D_color(fake_color) #生成画像に対するDiscriminatorの結果
    232 
    233            adversarial_color_loss_fake = criterion(output,real_target) #Discriminatorの出力結果と正解ラベルとのBCELoss
    234 
    235            l1_loss = criterion2(real_image,fake_color) #生成結果と正解画像のL1Loss
    236 
    237            loss_g = adversarial_color_loss_fake +l1_loss*l1Loss_late #二つの損失をバランスを考えて加算
    238            
    239            loss_g.backward(retain_graph = True) # 誤差逆伝播
    240            optimizergray2color.step()  # Generatorのパラメータ更新
    241
    242            #------Discriminatorの学習-------
    243
    244            #勾配情報の初期化
    245            gray2color.zero_grad() 
    246            D_color.zero_grad()
    247
    248            fake_color = gray2color(gray)#生成画像
    249
    250            output = D_color(fake_color) #生成画像に対するDiscriminatorの出力
    251 
    252            adversarial_color_loss_fake = criterion(output,fake_target) #Discriminatorの出力結果と偽物ラベルとのBCELoss
    253            
    254            
    255            output = D_color(real_image) #正解画像に対するDiscriminatorの出力
    256            adversarial_color_loss_real = criterion(output,real_target)#2  #Discriminatorの出力結果と本物ラベルとのBCELoss
    257
    258            loss_d = adversarial_color_loss_fake+adversarial_color_loss_real #単純に加算
    259            loss_d.backward(retain_graph = True) # 誤差逆伝播
    260            optimizerD_color.step()  # Discriminatorのパラメータ更新
    261
    262            if i % 100==0:
    263              if save_img == True:
    264                value = int(math.sqrt(batch_size))
    265                pic = to_img(gray.cpu().data)
    266                pic = torchvision.utils.make_grid(pic,nrow = value)
    267                save_image(pic, './mono_image_{}.png'.format(i))  #白黒画像を保存
    268
    269                pic = to_img(fake_color.cpu().data)
    270                pic = torchvision.utils.make_grid(pic,nrow = value)
    271                save_image(pic, './fake_image_{}.png'.format(i))  #生成画像を保存
    272
    273              print(i, len(dataloader),"g",loss_g,"L1Loss",l1_loss,"d",loss_d)    
    274          
    275        if train == True:
    276                #モデルを保存
    277                torch.save(gray2color.state_dict(), './drive/My Drive/gray2color.pth')
    278                torch.save(D_color.state_dict(), './drive/My Drive/D_color.pth')
    279                #ここのパスは自分のGoogleDriveのパスに合うように変えてください
    280    
    281if __name__ == '__main__':
    282    main()
    広告メディア事業部

    広告メディア事業部

    おすすめ記事