pix2pixで白黒画像をカラー変換する
IT技術
pix2pix とは?
pix2pix は、画像から画像への変換に用いられます。
例えば、「白黒画像からカラー画像への変換」であったり、「ラベル画像から元画像への変換」であったりなどです。
pix2pix のメリット
では、画像 to 画像の変換ができるようになると、どんなメリットがあるのでしょうか?
例えば、デジタルイラストは線画を書いた後、ペイントツールなどで色を塗ります。
その際、細かい調整が効かなかったりするため、色塗りが大変であることは言うまでもありません。
しかし、pix2pix では画像のセットさえあれば何でも変換できるため、「線画から色のついた画像を生成する」ことができるようになります。
これが実現すれば、デジタルイラストレーターの時間効率が飛躍的に上がるようになります。
pix2pix の学習法
Generator の学習
通常の GAN に比べ、pix2pix では直接、生成画像と正解画像を近づける項を設けます。
画像を直接近づける際の誤差関数は、「L1Loss」をとります。
pix2pix の論文でも、「L1Loss」を取っています。
【Image-to-Image Translation with Conditional Adversarial Networks】
https://arxiv.org/pdf/1611.07004.pdf
Discriminator の学習
一方、「Discriminator」は、通常の GAN と変わりません。
正解画像は「1」に近づけ、不正解画像を「0」に近づけさせ、「偽物」か「本物」かを、Discriminator が見極められるようにします。
Generatorの構造
Generator は、「U-net 構造」をとります。
U-net 構造とは?
U-net は、以下のような構造をしています。
「画像 to 画像」の変換をする際、画像がボケることがあります。
オートエンコーダのようなボトルネックがある場合、その分の次元が圧縮されてしまうため、復元が難しくなるからです。
U-net もオートエンコーダ同様、「画像to画像」の変換です。
画像の大局的な部分と局所的な部分の両方を加味することで、高解像度な画像を生成することができます。
そのため、Generator には 「U-net」を採用します。
実装時の学習
まずは、Generator を学習します。
損失関数には、「L1Loss」を加えています。
1real_image = data[0].to(device) # 本物画像
2real_target = torch.full((sample_size,), random.uniform(1, 1), device=device) # 本物ラベル
3fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル
4
5#______カラー画像の白黒画像化______#
6_img = torch.Tensor(real_image.shape[0],1,real_image.shape[2],real_image.shape[3]) #
7_img[:,0,:,:] = (real_image[:,0,:,:]+real_image[:,1,:,:]+real_image[:,2,:,:])/3
8gray =_img.to(device) #カラー画像を白黒画像にしたもの
9
10criterion = nn.BCELoss()
11criterion2 = nn.MSELoss()
12
13fake_color = gray2color(gray) #生成画像
14
15output = D_color(fake_color) #生成画像に対するDiscriminatorの結果
16
17adversarial_color_loss_fake = criterion(output,real_target) #Discriminatorの出力結果と正解ラベルとのBCELoss
18
19l1_loss = criterion2(real_image,fake_color) #生成結果と正解画像のL1Loss
20
21loss_g = adversarial_color_loss_fake +l1_loss*l1Loss_late #二つの損失をバランスを考えて加算
22
23loss_g.backward(retain_graph = True) # 誤差逆伝播
24
25optimizergray2color.step() # Generatorのパラメータ更新
Discriminator の学習
次に、Discriminator の学習を行います。
1fake_color = gray2color(gray) #Generatorの生成画像
2
3output = D_color(fake_color) #生成画像に対するDiscriminatorの出力
4
5adversarial_color_loss_fake = criterion(output,fake_target) #Discriminatorの出力結果と偽物ラベルとのBCELoss
6
7
8output = D_color(real_image) #正解画像に対するDiscriminatorの出力
9adversarial_color_loss_real = criterion(output,real_target)#2 #Discriminatorの出力結果と本物ラベルとのBCELoss
10
11
12loss_d = adversarial_color_loss_fake+adversarial_color_loss_real #単純に加算
13loss_d.backward(retain_graph = True) # 誤差逆伝播
14optimizerD_color.step() # Discriminatorのパラメータ更新
学習時の工夫
学習が全く上手く行きませんでした。
原因の一つとして挙げられるのが、「バッチサイズ」です。
今までバッチサイズは大きければ良いと考えていましたが、どうやらバッチサイズが大きいほど過学習してしまうみたいです。
【バッチサイズは小さい方が良い】
http://marugari2.hatenablog.jp/entry/2017/12/04/174535
バッチサイズが大きいと、平均的に誤差を下げます。
そのため、二乗誤差を使った時と同様に、データセットに最も多い肌色を塗りやすくなり、他の色を使わなくなってしまいます。
そのため、学習時のバッチサイズは「1」にすることで、ようやく上手く行きました。
実験結果
さて今回は、「単純CNN で学習したモデル」と、「pix2pix で学習したモデル」それぞれで白黒画像のカラー化を行いました。
単純な CNN は、以下のリンクからコードを参照しています。
単純な CNN と pix2pix との比較
結果から見ると、「単純な CNNで 学習した場合」は人の肌色ばかり塗り、他の背景部分などは色がついていません。
二乗誤差を用いることで、平均的にピクセルレベルの誤差が小さくなれば良いため、データセットの中で最も多い「肌色」を塗っているからです。
しかし、pix2pix ではピクセル単位での誤差に加えて、画像全体としての評価も行っているため、より色合いが「鮮やか」になっていることがわかります。
損失に MSE を使った時との比較
現在、pix2pix の損失関数には「L1Loss」を使っています。
今度は、MSE つまり「L2Loss」を使ってみると結果がどうなるのか、比較してみたいと思います。
以下は、その実験結果です。
この画像から見てわかるように、色は確かに「GAN」により多様性が出てきました。
L1Loss を使った時より、L2 を使った時の方がボヤッとした色合いになっていることがわかります。
これも、損失に二乗誤差を用いていることによる弊害ですね。
GAN を用いても色のボヤッとした感じは残るようです。
というわけで、pix2pix の損失には「L1Loss」を使った方がいいという結論に至りました!
さいごに
「pix2pix」で、よりリアルな色合いのカラー画像の生成が可能になりました。
L1 と L2 で生成結果の差が出ることも確認できました。
「画像 to 画像」の変換に対して何らかのタスクを抱えている場合、pix2pix の実装は一考の価値ありです。
次回は、pix2pix の発展版である「cycleGAN」での画像変換を行いたいと思います。
次回の記事はこちら
2020.08.07cycleGANで男顔⇄女顔への変換を可能してみた!cycleGANとは?「cycleGAN」は pix2pix と違い、ペア画像を必要としない、より画期的な「画像 to...
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
実験に用いたソースコード
最後に、実験に用いたソースコードを紹介します。
1import torch
2import torchvision
3
4import torchvision.datasets as dset
5from torch import nn
6from torch.autograd import Variable
7from torch.utils.data import DataLoader
8from torchvision import transforms
9from torchvision.utils import save_image
10from mpl_toolkits.mplot3d import axes3d
11from torchvision.datasets import MNIST
12import os
13import math
14import pylab
15import matplotlib.pyplot as plt
16
17beta1 = 0.5
18l1Loss_late =100 #L1LossとadversarilLossの重要度を決定する係数
19num_epochs = 1 #エポック数
20batch_size = 1 #バッチサイズ
21learning_rate = 1e-3 #学習率
22train =True#学習を行うかどうかのフラグ
23pretrained =False#事前に学習したモデルがあるならそれを使う
24save_img = False #ネットワークによる生成画像を保存するかどうかのフラグ
25
26import random
27def to_img(x):
28 x = 0.5 * (x + 1)
29 x = x.clamp(0, 1)
30 x = x.view(x.size(0), x.shape[1], x.shape[2],x.shape[3])
31 return x
32
33
34#データセットを調整する関数
35transform = transforms.Compose(
36 [transforms.ToTensor(),
37 transforms.Normalize((0.5, ), (0.5, ))])
38
39#訓練用データセット
40#ここのパスは自分のGoogleDriveのパスに合うように変えてください
41dataset = dset.ImageFolder(root='./drive/My Drive/face/tmp3/',
42 transform=transforms.Compose([
43 transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)),
44 transforms.RandomHorizontalFlip(),
45 transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
46 transforms.ToTensor(),
47 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
48 ]))
49
50#データセットをdataoaderで読み込み
51dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
52
53#pix2pixのGenerator部分
54class Generator(nn.Module):
55 def __init__(self,nch,nch_d):
56 super(Generator, self).__init__()
57 nch_g = 64
58 #U-net部分
59 self.layer1 = self.conv_layer_forward(nch, nch_g , 3, 2, 1)
60 self.layer2 = self.conv_layer_forward(nch_g , nch_g*2 , 3, 2, 1)
61 self.layer3 = self.conv_layer_forward(nch_g*2 , nch_g*4 , 3, 2, 1)
62 self.layer4= self.conv_layer_forward(nch_g*4 , nch_g*8 , 3, 2, 1)
63 self.layer5= self.conv_layer_forward(nch_g*8 , nch_g*16 , 3, 2, 1)
64 self.layer6= self.conv_layer_forward_image_size_1(nch_g*16 , nch_g*32 , 4, 1, 1)
65 self.layer7= self.conv_layer_transpose(nch_g*32 , nch_g*16 , 4, 2, 1,False)
66 self.layer8 = self.conv_layer_transpose(nch_g*32 , nch_g*8 , 4, 2, 1,False)
67 self.layer9 = self.conv_layer_transpose(nch_g*16 , nch_g*4 , 4, 2, 1,False)
68 self.layer10= self.conv_layer_transpose(nch_g*8 , nch_g*2 , 4, 2, 1,False)
69 self.layer11= self.conv_layer_transpose(nch_g*4 , nch_g , 4, 2, 1,False)
70 self.layer12 = self.conv_layer_transpose(nch_g*2 , nch_d , 4, 2, 1,True)
71
72 def forward(self, z):
73 z,z1 = self.convolution_forward(self.layer1,z)
74 z,z2= self.convolution_forward(self.layer2,z)
75 z,z3 = self.convolution_forward(self.layer3,z)
76 z,z4 = self.convolution_forward(self.layer4,z)
77 z,z5 = self.convolution_forward(self.layer5,z)
78 z = self.convolution(self.layer6,z)
79 z = self.convolution_deconv(self.layer7,z,z5)
80 z = self.convolution_deconv(self.layer8,z,z4)
81 z = self.convolution_deconv(self.layer9,z,z3)
82 z = self.convolution_deconv(self.layer10,z,z2)
83 z = self.convolution_deconv(self.layer11,z,z1)
84 z = self.convolution(self.layer12,z)
85 return z
86
87 def convolution(self,layer_i,z):
88 for layer in layer_i.values():
89 z = layer(z)
90 return z
91
92 def conv_layer_forward(self,input,out,kernel_size,stride,padding):
93 return nn.ModuleDict({
94 'layer0': nn.Sequential(
95 nn.Conv2d(input,out,kernel_size,stride,padding),
96 nn.BatchNorm2d(out),
97 nn.ReLU()
98 ),
99 })
100
101 def conv_layer_forward_image_size_1(self,input,out,kernel_size,stride,padding):
102 return nn.ModuleDict({
103 'layer0': nn.Sequential(
104 nn.Conv2d(input,out,kernel_size,stride,padding),
105 nn.ReLU()
106 ),
107 })
108
109 def conv_layer_transpose(self,input,out,kernel_size,stride,padding,is_last):
110 if is_last == True:
111 return nn.ModuleDict({
112 'layer0': nn.Sequential(
113 nn.ConvTranspose2d(input , out , kernel_size, stride, padding),
114 nn.Tanh()
115 ),
116 })
117 else :
118 return nn.ModuleDict({
119 'layer0': nn.Sequential(
120 nn.ConvTranspose2d(input , out , kernel_size, stride, padding),
121 nn.BatchNorm2d(out),
122 nn.ReLU()
123 ),
124 })
125
126 def convolution_forward(self,layer,z):
127 z = self.convolution(layer,z)
128 z_copy = z
129 return z,z_copy
130 def convolution_deconv(self,layer,z,z_copy):
131 z = self.convolution(layer,z)
132 z = torch.cat([z,z_copy],dim = 1)
133 return z
134
135
136class Discriminator(nn.Module):
137 #Dicriminator部分
138 def __init__(self, nch=3, nch_d=64):
139 super(Discriminator, self).__init__()
140 self.layer1 = self.conv_layer(nch, nch_d, 4, 2, 1,False)
141 self.layer2 = self.conv_layer(nch_d, nch_d * 2, 4, 2, 1,False)
142 self.layer3 = self.conv_layer(nch_d * 2, nch_d * 4, 4, 2, 1,False)
143 self.layer4 = self.conv_layer(nch_d * 4, nch_d * 8, 4, 2, 1,False)
144 self.layer5 = self.conv_layer(nch_d * 8, 1, 4, 1,0,True)
145
146 def conv_layer(self,input,out,kernel_size,stride,padding,is_last):
147 if is_last == True:
148 return nn.ModuleDict({
149 'layer0': nn.Sequential(
150 nn.Conv2d(input , out , kernel_size, stride, padding),
151 nn.Tanh()
152 ),
153 })
154 else :
155 return nn.ModuleDict({
156 'layer0': nn.Sequential(
157 nn.Conv2d(input , out , kernel_size, stride, padding),
158 nn.BatchNorm2d(out),
159 nn.ReLU()
160 ),
161 })
162
163 def convolution(self,layer_i,z):
164 for layer in layer_i.values():
165 z = layer(z)
166 return z
167 def forward(self, x):
168 x = self.convolution(self.layer1,x)
169 x = self.convolution(self.layer2,x)
170 x = self.convolution(self.layer3,x)
171 x = self.convolution(self.layer4,x)
172 x = self.convolution(self.layer5,x)
173
174 return x
175def main():
176 #もしGPUがあるならGPUを使用してないならCPUを使用
177 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
178
179 #ネットワークを呼び出し
180 gray2color = Generator(1,3).to(device)
181
182
183 #事前に学習しているモデルがあるならそれを読み込む
184 #ここのパスは自分のGoogleDriveパスに合うように変えてください
185 #./drive/My Drive/までは変えなくてできます
186
187 if pretrained:
188 param = torch.load('./drive/My Drive/gray2color.pth')
189 gray2color.load_state_dict(param)
190
191 D_color = Discriminator(nch=3,nch_d=64).to(device)
192 if pretrained:
193 param = torch.load('./drive/My Drive/D_color.pth')
194
195 D_color.load_state_dict(param)
196
197 #誤差関数には二乗誤差を使用
198 criterion = nn.BCELoss()
199 criterion2 = nn.L1Loss()
200 #更新式はAdamを適用
201
202 optimizerD_color = torch.optim.Adam(D_color.parameters(), lr=learning_rate, betas=(beta1, 0.999), weight_decay=1e-5)
203 optimizergray2color = torch.optim.Adam(gray2color.parameters(), lr=learning_rate, betas=(beta1, 0.999), weight_decay=1e-5)
204
205 loss_train_list = []
206 loss_test_list= []
207 for epoch in range(num_epochs):
208 print(epoch)
209 i=0
210 for data in dataloader:
211 i=i+1
212 real_image = data[0].to(device) # 本物画像
213 sample_size = real_image.size(0) # 画像枚数
214 real_target = torch.full((sample_size,), random.uniform(1, 1), device=device) # 本物ラベル
215 fake_target = torch.full((sample_size,), random.uniform(0, 0), device=device) # 偽物ラベル
216
217 #_imgはカラー画像をモノクロに変換した画像
218 _img = torch.Tensor(real_image.shape[0],1,real_image.shape[2],real_image.shape[3])
219 _img[:,0,:,:] = (real_image[:,0,:,:]+real_image[:,1,:,:]+real_image[:,2,:,:])/3
220 #_img = (real_image[:,:,::2, ::2] +real_image[:,:,1::2, ::2] + real_image[:,:,::2, 1::2] +real_image[:,:,1::2, 1::2])/4
221 gray =_img.to(device)
222
223 #--------Generatorの学習-------
224
225 #勾配情報の初期化
226 gray2color.zero_grad()
227 D_color.zero_grad()
228
229 fake_color = gray2color(gray) #生成画像
230
231 output = D_color(fake_color) #生成画像に対するDiscriminatorの結果
232
233 adversarial_color_loss_fake = criterion(output,real_target) #Discriminatorの出力結果と正解ラベルとのBCELoss
234
235 l1_loss = criterion2(real_image,fake_color) #生成結果と正解画像のL1Loss
236
237 loss_g = adversarial_color_loss_fake +l1_loss*l1Loss_late #二つの損失をバランスを考えて加算
238
239 loss_g.backward(retain_graph = True) # 誤差逆伝播
240 optimizergray2color.step() # Generatorのパラメータ更新
241
242 #------Discriminatorの学習-------
243
244 #勾配情報の初期化
245 gray2color.zero_grad()
246 D_color.zero_grad()
247
248 fake_color = gray2color(gray)#生成画像
249
250 output = D_color(fake_color) #生成画像に対するDiscriminatorの出力
251
252 adversarial_color_loss_fake = criterion(output,fake_target) #Discriminatorの出力結果と偽物ラベルとのBCELoss
253
254
255 output = D_color(real_image) #正解画像に対するDiscriminatorの出力
256 adversarial_color_loss_real = criterion(output,real_target)#2 #Discriminatorの出力結果と本物ラベルとのBCELoss
257
258 loss_d = adversarial_color_loss_fake+adversarial_color_loss_real #単純に加算
259 loss_d.backward(retain_graph = True) # 誤差逆伝播
260 optimizerD_color.step() # Discriminatorのパラメータ更新
261
262 if i % 100==0:
263 if save_img == True:
264 value = int(math.sqrt(batch_size))
265 pic = to_img(gray.cpu().data)
266 pic = torchvision.utils.make_grid(pic,nrow = value)
267 save_image(pic, './mono_image_{}.png'.format(i)) #白黒画像を保存
268
269 pic = to_img(fake_color.cpu().data)
270 pic = torchvision.utils.make_grid(pic,nrow = value)
271 save_image(pic, './fake_image_{}.png'.format(i)) #生成画像を保存
272
273 print(i, len(dataloader),"g",loss_g,"L1Loss",l1_loss,"d",loss_d)
274
275 if train == True:
276 #モデルを保存
277 torch.save(gray2color.state_dict(), './drive/My Drive/gray2color.pth')
278 torch.save(D_color.state_dict(), './drive/My Drive/D_color.pth')
279 #ここのパスは自分のGoogleDriveのパスに合うように変えてください
280
281if __name__ == '__main__':
282 main()
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit
おすすめ記事
immichを知ってほしい
2024.10.31