
pix2pixで白黒画像をカラー変換する
2020.08.13
pix2pix とは?
pix2pix は、画像から画像への変換に用いられます。
例えば、「白黒画像からカラー画像への変換」であったり、「ラベル画像から元画像への変換」であったりなどです。

pix2pix のメリット
では、画像 to 画像の変換ができるようになると、どんなメリットがあるのでしょうか?
例えば、デジタルイラストは線画を書いた後、ペイントツールなどで色を塗ります。
その際、細かい調整が効かなかったりするため、色塗りが大変であることは言うまでもありません。
しかし、pix2pix では画像のセットさえあれば何でも変換できるため、「線画から色のついた画像を生成する」ことができるようになります。
これが実現すれば、デジタルイラストレーターの時間効率が飛躍的に上がるようになります。

こちらの記事もオススメ!
pix2pix の学習法
Generator の学習

通常の GAN に比べ、pix2pix では直接、生成画像と正解画像を近づける項を設けます。
画像を直接近づける際の誤差関数は、「L1Loss」をとります。
pix2pix の論文でも、「L1Loss」を取っています。
【Image-to-Image Translation with Conditional Adversarial Networks】
https://arxiv.org/pdf/1611.07004.pdf
Discriminator の学習

一方、「Discriminator」は、通常の GAN と変わりません。
正解画像は「1」に近づけ、不正解画像を「0」に近づけさせ、「偽物」か「本物」かを、Discriminator が見極められるようにします。
Generatorの構造
Generator は、「U-net 構造」をとります。
U-net 構造とは?
U-net は、以下のような構造をしています。
「画像 to 画像」の変換をする際、画像がボケることがあります。
オートエンコーダのようなボトルネックがある場合、その分の次元が圧縮されてしまうため、復元が難しくなるからです。
U-net もオートエンコーダ同様、「画像to画像」の変換です。
画像の大局的な部分と局所的な部分の両方を加味することで、高解像度な画像を生成することができます。
そのため、Generator には 「U-net」を採用します。
実装時の学習
まずは、Generator を学習します。
損失関数には、「L1Loss」を加えています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | real_image = data[0].to(device) # 本物画像 real_target = torch.full((sample_size,), random.uniform(1, 1), device=device) # 本物ラベル fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル #______カラー画像の白黒画像化______# _img = torch.Tensor(real_image.shape[0],1,real_image.shape[2],real_image.shape[3]) # _img[:,0,:,:] = (real_image[:,0,:,:]+real_image[:,1,:,:]+real_image[:,2,:,:])/3 gray =_img.to(device) #カラー画像を白黒画像にしたもの criterion = nn.BCELoss() criterion2 = nn.MSELoss() fake_color = gray2color(gray) #生成画像 output = D_color(fake_color) #生成画像に対するDiscriminatorの結果 adversarial_color_loss_fake = criterion(output,real_target) #Discriminatorの出力結果と正解ラベルとのBCELoss l1_loss = criterion2(real_image,fake_color) #生成結果と正解画像のL1Loss loss_g = adversarial_color_loss_fake +l1_loss*l1Loss_late #二つの損失をバランスを考えて加算 loss_g.backward(retain_graph = True) # 誤差逆伝播 optimizergray2color.step() # Generatorのパラメータ更新 |
Discriminator の学習
次に、Discriminator の学習を行います。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | fake_color = gray2color(gray) #Generatorの生成画像 output = D_color(fake_color) #生成画像に対するDiscriminatorの出力 adversarial_color_loss_fake = criterion(output,fake_target) #Discriminatorの出力結果と偽物ラベルとのBCELoss output = D_color(real_image) #正解画像に対するDiscriminatorの出力 adversarial_color_loss_real = criterion(output,real_target)#2 #Discriminatorの出力結果と本物ラベルとのBCELoss loss_d = adversarial_color_loss_fake+adversarial_color_loss_real #単純に加算 loss_d.backward(retain_graph = True) # 誤差逆伝播 optimizerD_color.step() # Discriminatorのパラメータ更新 |
学習時の工夫
学習が全く上手く行きませんでした。
原因の一つとして挙げられるのが、「バッチサイズ」です。
今までバッチサイズは大きければ良いと考えていましたが、どうやらバッチサイズが大きいほど過学習してしまうみたいです。
【バッチサイズは小さい方が良い】
http://marugari2.hatenablog.jp/entry/2017/12/04/174535
バッチサイズが大きいと、平均的に誤差を下げます。
そのため、二乗誤差を使った時と同様に、データセットに最も多い肌色を塗りやすくなり、他の色を使わなくなってしまいます。
そのため、学習時のバッチサイズは「1」にすることで、ようやく上手く行きました。
実験結果
さて今回は、「単純CNN で学習したモデル」と、「pix2pix で学習したモデル」それぞれで白黒画像のカラー化を行いました。
単純な CNN は、以下のリンクからコードを参照しています。
単純な CNN と pix2pix との比較
結果から見ると、「単純な CNNで 学習した場合」は人の肌色ばかり塗り、他の背景部分などは色がついていません。
二乗誤差を用いることで、平均的にピクセルレベルの誤差が小さくなれば良いため、データセットの中で最も多い「肌色」を塗っているからです。
しかし、pix2pix ではピクセル単位での誤差に加えて、画像全体としての評価も行っているため、より色合いが「鮮やか」になっていることがわかります。
損失に MSE を使った時との比較
現在、pix2pix の損失関数には「L1Loss」を使っています。
今度は、MSE つまり「L2Loss」を使ってみると結果がどうなるのか、比較してみたいと思います。
以下は、その実験結果です。
この画像から見てわかるように、色は確かに「GAN」により多様性が出てきました。
L1Loss を使った時より、L2 を使った時の方がボヤッとした色合いになっていることがわかります。
これも、損失に二乗誤差を用いていることによる弊害ですね。
GAN を用いても色のボヤッとした感じは残るようです。
というわけで、pix2pix の損失には「L1Loss」を使った方がいいという結論に至りました!
さいごに
「pix2pix」で、よりリアルな色合いのカラー画像の生成が可能になりました。
L1 と L2 で生成結果の差が出ることも確認できました。
「画像 to 画像」の変換に対して何らかのタスクを抱えている場合、pix2pix の実装は一考の価値ありです。
次回は、pix2pix の発展版である「cycleGAN」での画像変換を行いたいと思います。
(株)ライトコードは、WEB・アプリ・ゲーム開発に強い、「好きを仕事にするエンジニア集団」です。
機械学習でのシステム開発依頼・お見積もりはこちらまでお願いします。
また、機械学習系エンジニアを積極採用中です!詳しくはこちらをご覧ください。
※現在、多数のお問合せを頂いており、返信に、多少お時間を頂く場合がございます。
次回の記事はこちら
こちらの記事もオススメ!
実験に用いたソースコード
最後に、実験に用いたソースコードを紹介します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 | import torch import torchvision import torchvision.datasets as dset from torch import nn from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import transforms from torchvision.utils import save_image from mpl_toolkits.mplot3d import axes3d from torchvision.datasets import MNIST import os import math import pylab import matplotlib.pyplot as plt beta1 = 0.5 l1Loss_late =100 #L1LossとadversarilLossの重要度を決定する係数 num_epochs = 1 #エポック数 batch_size = 1 #バッチサイズ learning_rate = 1e-3 #学習率 train =True#学習を行うかどうかのフラグ pretrained =False#事前に学習したモデルがあるならそれを使う save_img = False #ネットワークによる生成画像を保存するかどうかのフラグ import random def to_img(x): x = 0.5 * (x + 1) x = x.clamp(0, 1) x = x.view(x.size(0), x.shape[1], x.shape[2],x.shape[3]) return x #データセットを調整する関数 transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ))]) #訓練用データセット #ここのパスは自分のGoogleDriveのパスに合うように変えてください dataset = dset.ImageFolder(root='./drive/My Drive/face/tmp3/', transform=transforms.Compose([ transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])) #データセットをdataoaderで読み込み dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) #pix2pixのGenerator部分 class Generator(nn.Module): def __init__(self,nch,nch_d): super(Generator, self).__init__() nch_g = 64 #U-net部分 self.layer1 = self.conv_layer_forward(nch, nch_g , 3, 2, 1) self.layer2 = self.conv_layer_forward(nch_g , nch_g*2 , 3, 2, 1) self.layer3 = self.conv_layer_forward(nch_g*2 , nch_g*4 , 3, 2, 1) self.layer4= self.conv_layer_forward(nch_g*4 , nch_g*8 , 3, 2, 1) self.layer5= self.conv_layer_forward(nch_g*8 , nch_g*16 , 3, 2, 1) self.layer6= self.conv_layer_forward_image_size_1(nch_g*16 , nch_g*32 , 4, 1, 1) self.layer7= self.conv_layer_transpose(nch_g*32 , nch_g*16 , 4, 2, 1,False) self.layer8 = self.conv_layer_transpose(nch_g*32 , nch_g*8 , 4, 2, 1,False) self.layer9 = self.conv_layer_transpose(nch_g*16 , nch_g*4 , 4, 2, 1,False) self.layer10= self.conv_layer_transpose(nch_g*8 , nch_g*2 , 4, 2, 1,False) self.layer11= self.conv_layer_transpose(nch_g*4 , nch_g , 4, 2, 1,False) self.layer12 = self.conv_layer_transpose(nch_g*2 , nch_d , 4, 2, 1,True) def forward(self, z): z,z1 = self.convolution_forward(self.layer1,z) z,z2= self.convolution_forward(self.layer2,z) z,z3 = self.convolution_forward(self.layer3,z) z,z4 = self.convolution_forward(self.layer4,z) z,z5 = self.convolution_forward(self.layer5,z) z = self.convolution(self.layer6,z) z = self.convolution_deconv(self.layer7,z,z5) z = self.convolution_deconv(self.layer8,z,z4) z = self.convolution_deconv(self.layer9,z,z3) z = self.convolution_deconv(self.layer10,z,z2) z = self.convolution_deconv(self.layer11,z,z1) z = self.convolution(self.layer12,z) return z def convolution(self,layer_i,z): for layer in layer_i.values(): z = layer(z) return z def conv_layer_forward(self,input,out,kernel_size,stride,padding): return nn.ModuleDict({ 'layer0': nn.Sequential( nn.Conv2d(input,out,kernel_size,stride,padding), nn.BatchNorm2d(out), nn.ReLU() ), }) def conv_layer_forward_image_size_1(self,input,out,kernel_size,stride,padding): return nn.ModuleDict({ 'layer0': nn.Sequential( nn.Conv2d(input,out,kernel_size,stride,padding), nn.ReLU() ), }) def conv_layer_transpose(self,input,out,kernel_size,stride,padding,is_last): if is_last == True: return nn.ModuleDict({ 'layer0': nn.Sequential( nn.ConvTranspose2d(input , out , kernel_size, stride, padding), nn.Tanh() ), }) else : return nn.ModuleDict({ 'layer0': nn.Sequential( nn.ConvTranspose2d(input , out , kernel_size, stride, padding), nn.BatchNorm2d(out), nn.ReLU() ), }) def convolution_forward(self,layer,z): z = self.convolution(layer,z) z_copy = z return z,z_copy def convolution_deconv(self,layer,z,z_copy): z = self.convolution(layer,z) z = torch.cat([z,z_copy],dim = 1) return z class Discriminator(nn.Module): #Dicriminator部分 def __init__(self, nch=3, nch_d=64): super(Discriminator, self).__init__() self.layer1 = self.conv_layer(nch, nch_d, 4, 2, 1,False) self.layer2 = self.conv_layer(nch_d, nch_d * 2, 4, 2, 1,False) self.layer3 = self.conv_layer(nch_d * 2, nch_d * 4, 4, 2, 1,False) self.layer4 = self.conv_layer(nch_d * 4, nch_d * 8, 4, 2, 1,False) self.layer5 = self.conv_layer(nch_d * 8, 1, 4, 1,0,True) def conv_layer(self,input,out,kernel_size,stride,padding,is_last): if is_last == True: return nn.ModuleDict({ 'layer0': nn.Sequential( nn.Conv2d(input , out , kernel_size, stride, padding), nn.Tanh() ), }) else : return nn.ModuleDict({ 'layer0': nn.Sequential( nn.Conv2d(input , out , kernel_size, stride, padding), nn.BatchNorm2d(out), nn.ReLU() ), }) def convolution(self,layer_i,z): for layer in layer_i.values(): z = layer(z) return z def forward(self, x): x = self.convolution(self.layer1,x) x = self.convolution(self.layer2,x) x = self.convolution(self.layer3,x) x = self.convolution(self.layer4,x) x = self.convolution(self.layer5,x) return x def main(): #もしGPUがあるならGPUを使用してないならCPUを使用 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #ネットワークを呼び出し gray2color = Generator(1,3).to(device) #事前に学習しているモデルがあるならそれを読み込む #ここのパスは自分のGoogleDriveパスに合うように変えてください #./drive/My Drive/までは変えなくてできます if pretrained: param = torch.load('./drive/My Drive/gray2color.pth') gray2color.load_state_dict(param) D_color = Discriminator(nch=3,nch_d=64).to(device) if pretrained: param = torch.load('./drive/My Drive/D_color.pth') D_color.load_state_dict(param) #誤差関数には二乗誤差を使用 criterion = nn.BCELoss() criterion2 = nn.L1Loss() #更新式はAdamを適用 optimizerD_color = torch.optim.Adam(D_color.parameters(), lr=learning_rate, betas=(beta1, 0.999), weight_decay=1e-5) optimizergray2color = torch.optim.Adam(gray2color.parameters(), lr=learning_rate, betas=(beta1, 0.999), weight_decay=1e-5) loss_train_list = [] loss_test_list= [] for epoch in range(num_epochs): print(epoch) i=0 for data in dataloader: i=i+1 real_image = data[0].to(device) # 本物画像 sample_size = real_image.size(0) # 画像枚数 real_target = torch.full((sample_size,), random.uniform(1, 1), device=device) # 本物ラベル fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル #_imgはカラー画像をモノクロに変換した画像 _img = torch.Tensor(real_image.shape[0],1,real_image.shape[2],real_image.shape[3]) _img[:,0,:,:] = (real_image[:,0,:,:]+real_image[:,1,:,:]+real_image[:,2,:,:])/3 #_img = (real_image[:,:,::2, ::2] +real_image[:,:,1::2, ::2] + real_image[:,:,::2, 1::2] +real_image[:,:,1::2, 1::2])/4 gray =_img.to(device) #--------Generatorの学習------- #勾配情報の初期化 gray2color.zero_grad() D_color.zero_grad() fake_color = gray2color(gray) #生成画像 output = D_color(fake_color) #生成画像に対するDiscriminatorの結果 adversarial_color_loss_fake = criterion(output,real_target) #Discriminatorの出力結果と正解ラベルとのBCELoss l1_loss = criterion2(real_image,fake_color) #生成結果と正解画像のL1Loss loss_g = adversarial_color_loss_fake +l1_loss*l1Loss_late #二つの損失をバランスを考えて加算 loss_g.backward(retain_graph = True) # 誤差逆伝播 optimizergray2color.step() # Generatorのパラメータ更新 #------Discriminatorの学習------- #勾配情報の初期化 gray2color.zero_grad() D_color.zero_grad() fake_color = gray2color(gray)#生成画像 output = D_color(fake_color) #生成画像に対するDiscriminatorの出力 adversarial_color_loss_fake = criterion(output,fake_target) #Discriminatorの出力結果と偽物ラベルとのBCELoss output = D_color(real_image) #正解画像に対するDiscriminatorの出力 adversarial_color_loss_real = criterion(output,real_target)#2 #Discriminatorの出力結果と本物ラベルとのBCELoss loss_d = adversarial_color_loss_fake+adversarial_color_loss_real #単純に加算 loss_d.backward(retain_graph = True) # 誤差逆伝播 optimizerD_color.step() # Discriminatorのパラメータ更新 if i % 100==0: if save_img == True: value = int(math.sqrt(batch_size)) pic = to_img(gray.cpu().data) pic = torchvision.utils.make_grid(pic,nrow = value) save_image(pic, './mono_image_{}.png'.format(i)) #白黒画像を保存 pic = to_img(fake_color.cpu().data) pic = torchvision.utils.make_grid(pic,nrow = value) save_image(pic, './fake_image_{}.png'.format(i)) #生成画像を保存 print(i, len(dataloader),"g",loss_g,"L1Loss",l1_loss,"d",loss_d) if train == True: #モデルを保存 torch.save(gray2color.state_dict(), './drive/My Drive/gray2color.pth') torch.save(D_color.state_dict(), './drive/My Drive/D_color.pth') #ここのパスは自分のGoogleDriveのパスに合うように変えてください if __name__ == '__main__': main() |
ライトコードよりお知らせ






一緒に働いてくれる仲間を募集しております!
ライトコードでは、仲間を募集しております!
当社のモットーは「好きなことを仕事にするエンジニア集団」「エンジニアによるエンジニアのための会社」。エンジニアであるあなたの「やってみたいこと」を全力で応援する会社です。
また、ライトコードは現在、急成長中!だからこそ、あなたにお任せしたいやりがいのあるお仕事は沢山あります。「コアメンバー」として活躍してくれる、あなたからのご応募をお待ちしております!
なお、ご応募の前に、「話しだけ聞いてみたい」「社内の雰囲気を知りたい」という方はこちらをご覧ください。
ライトコードでは一緒に働いていただける方を募集しております!
採用情報はこちら書いた人はこんな人

IT技術2021.04.16【第1回】Djangoで日記アプリを作ろう~環境構築編~
IT技術2021.03.2910月20日メジャーアップデート!「Node.js v15」の新機能とは?
IT技術2021.03.02TypeScriptの型を問題形式で学べる「type-challenges」とは?
IT技術2021.03.01シスコルータのコンフィグ作成をPythonで自動化してみた!