【機械学習】CNNで低解像度な画像を高解像度に変換してみる
IT技術
CNNで低解像度な画像を高解像度に変換してみたい!
前回は、「畳み込みニューラルネットワーク(Convolutional Neural Network:CNN)」を用いて、顔画像のカラー復元を行いました。
小さな画像サイズで実験を行いましたが、実際の場面では限られたリソース(計算資源)の中で、大きな画像サイズを取り扱う場面も多くあります。
そこで今回は、CNNを用いて、前回行ったカラー化と画像の高解像度化を組み合わせた実験をしてみたいと思います。
前回の記事はこちら
低解像度画像と高解像度画像
一般的にボケている画像のことを「低解像度な画像」と呼びます。
それに対し「高解像度な画像」とは、ボケの少ない画像のことをいいます。
そもそもボケというのは、画像のピクセル数が減少することで起こります。
低解像度化のイメージ
例えば、左下の画像のように適当なブロック分けを行い、その一ブロック内の全画素値の平均値を新たな画素値とすることを、画像全体に行うことでピクセル数を減少させることができます。
低解像度化は簡単
右の画像は、元画像に対してピクセル数が減少しているため「低解像度な画像」と言えます。
このように、「高解像度の画像」から「低解像度な画像」を作ることは比較的容易にできます。
低解像度画像から高解像度画像への変換
では逆のパターンで、「低解像度な画像」から「高解像度な画像」を生成することを考えてみましょう。
先ほどのブロックに分けた例で言うと、「低解像度な画像」の高解像度化は、一つのピクセルから複数のピクセルの値を決定するということが必要になります。
高解像度化のイメージ
「高解像度化」の作業は、下図のように解が一意に決まらず、様々なパターンが存在することが容易に想像できます。
高解像度化は難しい
言うなれば、先ほどの「高解像度な画像」もその解の一つであるものの、正解かどうかは分かりません。
そのため、「高解像度化」を人間が手探りで行うのは途方もなく難しい作業なのです。
そこで、あり得る様々なパターンから最も自然な画像になるパターンをニューラルネットワークに出力してもらうと言うのが、CNN で「高解像度化」をするモチベーションです。
高解像度化を用いるメリット
では、「高解像度化」をすることでどんなメリットがあるのでしょう。
機械学習のボトルネック
これは前回も話したとおり、機械学習の研究が進んだ現在、ボトルネックとなるのは計算資源です。
「高解像度な画像」を使って機械学習を行うことができればそれでいいのですが、そうはいきません。
ただでさえ制限があるメモリに、ニューラルネットワークのパラメータを多くメモリに保存しなければなりません。
そのため、画像サイズでその足りないメモリの帳尻合わせをする必要があるのです。
ボトルネック解消のために
そのため病理画像など、「高解像度な画像」に対しては、かなり小さい画像でしか学習ができません。
そんな中で低解像度で学習したものを「高解像度化」できれば、メモリの節約にもなり、学習結果もより向上するのではないかと言うのが私の考えるメリットです。
高解像度化のための学習モデル
では実際に、「学習モデル」について話していきます。
U-net のような構造にしました。
また、白黒画像のカラー化のときと同じように、深い層で浅い層の情報を渡し、細部の画像情報と全体的な情報を両方とも加味できるようにしました。
私の中でこれが一番損失関数が小さくなったので、このモデルを選びました。
カラー化と高解像度化の併用実験
検証用データに対する精度
「低解像度画像」に比べて生成画像はボケが小さい、つまりは「高解像度」になっていることがわかります。
低解像度画像
生成画像
私自身、かなり精度が良くてびっくりしています。
SRGAN を使わなくてもある程度は精度が出るようです。
白黒画像のカラー化との併用実験
次に、白黒画像の「カラー化」と画像の「高解像度化」を併用して実験してみました。
併用実験の方法
併用実験の方法としては、以下のように行いました。
- 64×64の白黒画像をカラー化する
- 変換した64×64のカラー画像を128×128に高解像度化
こうすることで、実際の実験環境ではメモリの影響で学習できなかった画像サイズでも、カラー画像の変換ができるようになります。
実験!
実際に、以下の変換画像の比較を行ってみましょう!
- 「64×64のサイズ」で学習したモデルで「128×128の画像」をカラー化した場合
- 「64×64の画像」をカラー化した後で「128×128に高解像度化」した場合
カラー化のみ
カラー化+高解像度化
実験結果
見てわかるように、「カラー化のみ」の結果では、色が正しく塗れていない部分があり、ムラがあります。
しかし、「カラー化+高解像度化」の結果は、前者と比較してムラが少なく、「64×64のサイズ」で学習したモデルが「128×128の画像サイズ」にも適用できていることがわかります。
このように、高解像度化は他の実験に併用することも有効であると言えます。
様々なタスクへの応用
それは白黒画像のカラー化だけでなく、画像のセグメンテーションなど様々なタスクに応用することができます。
学習時の画像サイズに限界を感じている機械学習エンジニアの方は、ぜひ参考にしてみてくださ!
さいごに ~次回予告~
さて、画像の「カラー化」と「高解像度化」と行ってきました。
しかし、白黒画像を「カラー化」したモデルでは、よく見ると肌色を顔の近くに塗っているだけで、背景も肌色に塗られていることがわかります。
正直言って、自然な画像であるとは言えません。
さらに、「カラー化」との併用実験では、画像の「高解像度化」も、実際の画像の解像度に比べてボケていることがわかります。
これは、「損失関数による問題」です。
ピクセル単位での二乗誤差をとることで、画像が自然なパターンにならずに全体的にボヤッとした見た目になるからです。
そこで次回からは、ピクセル単位での誤差に加えて、「GAN」による「Adversarial Loss」を加えることで、画像のボケ除去や色の多様性の追加を行っていきたいと思います!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ソースコード
1import torch
2import torchvision
3
4import torchvision.datasets as dset
5from torch import nn
6from torch.autograd import Variable
7from torch.utils.data import DataLoader
8from torchvision import transforms
9from torchvision.utils import save_image
10from mpl_toolkits.mplot3d import axes3d
11from torchvision.datasets import MNIST
12import os
13import math
14import pylab
15import matplotlib.pyplot as plt
16
17
18
19num_epochs = 1 #エポック数
20batch_size = 100 #バッチサイズ
21learning_rate = 1e-3 #学習率
22
23train = True#学習を行うかどうかのフラグ
24pretrained =False#事前に学習したモデルがあるならそれを使う
25save_img = True #ネットワークによる生成画像を保存するかどうのフラグ
26
27def to_img(x):
28 x = 0.5 * (x + 1)
29 x = x.clamp(0, 1)
30 x = x.view(x.size(0), 3, x.shape[2], x.shape[3])
31 return x
32def to_img_mono(x):
33 x = 0.5 * (x + 1)
34 x = x.clamp(0, 1)
35 x = x.view(x.size(0), 3,x.shape[2], x.shape[3])
36 return x
37
38#データセットを調整する関数
39transform = transforms.Compose(
40 [transforms.ToTensor(),
41 transforms.Normalize((0.5, ), (0.5, ))])
42
43#訓練用データセット
44dataset = dset.ImageFolder(root='./drive/My Drive/face/',
45 transform=transforms.Compose([
46 transforms.RandomResizedCrop(64, scale=(1.0, 1.0), ratio=(1., 1.)),
47 transforms.RandomHorizontalFlip(),
48 transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
49 transforms.ToTensor(),
50 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
51 ]))
52
53#データセットをdataoaderで読み込み
54dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
55
56#低解像度画像を高解像度化するニューラルネットワーク
57class SizeDecoder(nn.Module):
58 def __init__(self):
59 super(SizeDecoder, self).__init__()
60 nch_g = 64
61 self.layer1 = nn.ModuleDict({
62 'layer0': nn.Sequential(
63 nn.Conv2d(3, nch_g , 3, 2, 1),
64 nn.BatchNorm2d(nch_g),
65 nn.ReLU()
66 ), # (1, 64, 64) -> (64, 32, 32)
67
68
69 })
70
71 self.layer2 = nn.ModuleDict({
72
73 'layer0': nn.Sequential(
74 nn.Conv2d(nch_g , nch_g*2 , 3, 2, 1),
75 nn.BatchNorm2d(nch_g*2),
76 nn.ReLU()
77 ), # (64, 32, 32) -> (128, 16, 16)
78 })
79 self.layer3 = nn.ModuleDict({
80
81 'layer0': nn.Sequential(
82 nn.Conv2d(nch_g*2 , nch_g*4 , 3, 2, 1),
83 nn.BatchNorm2d(nch_g*4),
84 nn.ReLU()
85 ), # (128, 16, 16) -> (256, 8, 8)
86
87 })
88 self.layer4= nn.ModuleDict({
89 'layer0': nn.Sequential(
90 nn.Conv2d(nch_g*4 , nch_g*8 , 3, 2, 1),
91 nn.BatchNorm2d(nch_g*8),
92 nn.ReLU()
93 ), # (256, 8, 8) -> (512, 4, 4)
94
95 })
96
97
98 self.layer7 = nn.ModuleDict({
99 'layer0': nn.Sequential(
100 nn.ConvTranspose2d(nch_g*8 , nch_g*4 , 4, 2, 1),
101 nn.BatchNorm2d(nch_g*4),
102 nn.ReLU()
103 ), # (512, 4, 4) -> (256, 8, 8)
104 })
105 self.layer8 = nn.ModuleDict({
106 'layer0': nn.Sequential(
107 nn.ConvTranspose2d(nch_g*4 , nch_g*2 , 4, 2, 1),
108 nn.BatchNorm2d(nch_g*2),
109 nn.ReLU()
110 ), # (256, 8,8) -> (128, 16, 16)
111 })
112 self.layer9= nn.ModuleDict({
113 'layer0': nn.Sequential(
114 nn.ConvTranspose2d(nch_g*2 , nch_g , 4, 2, 1),
115 nn.BatchNorm2d(nch_g),
116 nn.ReLU()
117 ), # (128, 16, 16) -> (64, 32, 32)
118 })
119 self.layer10 = nn.ModuleDict({
120 'layer0': nn.Sequential(
121 nn.ConvTranspose2d(nch_g,int(nch_g/2) , 4, 2, 1),
122 nn.BatchNorm2d(int(nch_g/2)),
123 nn.Tanh()
124 ), # (64, 32, 32) -> (32, 64, 64)
125 })
126 self.layer11 = nn.ModuleDict({
127 'layer0': nn.Sequential(
128 nn.ConvTranspose2d(int(nch_g/2) , 3 , 4, 2, 1),
129 nn.BatchNorm2d(3),
130 nn.Tanh()
131 ), # (32, 64, 64) -> (3, 128, 128)
132 })
133
134 def forward(self, z):
135
136 for layer in self.layer1.values():
137 z = layer(z)
138 z1 =z
139 for layer in self.layer2.values():
140 z = layer(z)
141 z2 =z
142 for layer in self.layer3.values():
143 z = layer(z)
144 z3=z
145 for layer in self.layer4.values():
146 z = layer(z)
147
148 for layer in self.layer7.values():
149 z = layer(z)
150 z =z+z3
151
152 for layer in self.layer8.values():
153 z = layer(z)
154 z =z+z2
155 for layer in self.layer9.values():
156 z = layer(z)
157 z =z+z1
158
159 for layer in self.layer10.values():
160 z = layer(z)
161
162 for layer in self.layer11.values():
163 z = layer(z)
164 return z
165
166
167def main():
168 #もしGPUがあるならGPUを使用してないならCPUを使用
169 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
170
171 #ネットワークを呼び出し
172 model = SizeDecoder().to(device)
173
174
175 #事前に学習しているモデルがあるならそれを読み込む
176 if pretrained:
177 param = torch.load('./Size_Decoder.pth')
178 model.load_state_dict(param)
179
180 #誤差関数には二乗誤差を使用
181 criterion = nn.MSELoss()
182
183 #更新式はAdamを適用
184 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
185
186 weight_decay=1e-5)
187
188
189 loss_train_list = []
190 loss_test_list= []
191 for epoch in range(num_epochs):
192
193 print(epoch)
194
195 for data in dataloader:
196
197 img, num = data
198 #img --> [batch_size,1,32,32]
199 #imgは元画像
200 #imgをGPUに載せる
201 img = Variable(img).to(device)
202
203 # ===================forward=====================
204
205 #_imgは高解像度画像を低解像度に変換した画像
206
207 _img = (img[:,:,::2, ::2] + img[:,:,1::2, ::2] + img[:,:,::2, 1::2] + img[:,:,1::2, 1::2])/4
208
209
210 _img =_img.to(device)
211
212 #ネットワークの出力結果
213 output = model(_img)
214 print(output.shape)
215 #もし学習するなら
216 if train:
217 #ネットワークの出力と高解像度画像との誤差を損失として学習
218
219
220 # ===================backward====================
221 loss = criterion(output, img)
222 print(loss)
223 #勾配を初期化
224 optimizer.zero_grad()
225
226 #微分値を計算
227 loss.backward()
228
229 #パラメータを更新
230 optimizer.step()
231
232
233 else:#学習しないなら
234 break
235 # ===================log========================
236
237 if train == True:
238 #モデルを保存
239 torch.save(model.state_dict(), './Size_Decoder.pth')
240
241
242 #もし生成画像を保存するなら
243 if save_img:
244 value = int(math.sqrt(batch_size))
245
246 pic = to_img(img.cpu().data)
247 pic = torchvision.utils.make_grid(pic,nrow = value)
248 save_image(pic, './real_image_{}.png'.format(epoch)) #元画像を保存
249
250 pic = to_img_mono(output.cpu().data)
251 pic = torchvision.utils.make_grid(pic,nrow = value)
252 save_image(pic, './image_{}.png'.format(epoch)) #生成画像
253if __name__ == '__main__':
254 main()
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit
おすすめ記事
浮動小数点について調べてみた
2024.09.09