教師なし機械学習「VAE」による連続的な手書き文字の生成
IT技術
認識モデルから生成モデルへ
以前まで、手書き文字の認識は、難しいタスクであると考えられてきました。
しかし、ニューラルネットワークの開発が進んだ2020年現在では、比較的、簡単な課題へと変化しました。
手書き文字の認識モデル
このような「認識モデル」では、手書き文字などのデータの分布を考慮せず、与えられたデータそのものから、直接識別境界を引いていきます。
手書き文字の生成モデル
それに対し、「生成モデル」は与えられたデータ群の分布そのものを求めていきます。
生成モデルのメリット
生成モデルでは、分布が分かっているため、任意のデータを生成できます。
これは、機械学習タスクの弱点の一つ、データ不足を補えます。
また、データを連続的に変化させることが可能となり、モーフィングを行うことができるようになります。
Variational AutoEncoder(VAE)を用いて顔画像の表情・向きを変化させる
例えば、映画やアニメなどで、キャラクターが「普通の顔」から「笑顔」に変化していくアニメーションを作成するとしましょう。
このとき、様々な表情を学習した生成モデルを用いれば、イラストレーターなしに、アニメーションの連続的な変化を作成することができます。
VAE で同一人物の顔画像を学習させ変化させた図
以下は、VAE により、同一人物の顔画像を学習させ、連続的な表情・顔の向きを変化させた図です。
AutoEncoder と Variational AutoEncoder(VAE)
前回、AutoEncoder を用いた次元削減について話しましたが、AutoEncoder では、データと潜在変数の関係を として次元埋め込みをしていました。
「Variational AutoEncoder(VAE)」では、潜在変数に確率的なブレを与えることで、与えられたデータ群の分布を推定し、連続的な画像生成を可能にします。
つまり、VAE は、生成モデルの一つとなります。
AutoEncoder による次元削減の記事
なお、AutoEncoder を用いた次元削減については、以下の記事で解説しています。
Variational AutoEncoder(VAE) の理論
ネットワーク構造
さて、「生成モデルは、データの分布を求めること」と述べました。
今後は、データを
として、その確率分布を とします。
そして、AutoEncoder の時にもお話したように、画像のような高次元データは、ほとんどの画素値が冗長であり、周りの画素値で補完できます。
そのため、実際には、データはより低次元に分布するはずです。
「0~9」の手書き文字画像を3次元の潜在変数に次元削減する
以下は、AutoEncoder により、「0~9」の手書き文字画像を3次元の潜在変数に次元削減したときの例です。
このような低次元の潜在変数を とし、その確率分布を とします。
潜在変数 z をニューラルネットワークで求める
AutoEncoder 同様、潜在変数 をニューラルネットワークによって求めていきます。
しかし、AutoEncoder のように、直接 を求めるわけではありません。
分布 が正規分布に従うとし、ニューラルネットワークは、 をサンプリングするため、正規分布の平均 と、分散 を出力します。
(詳しくは後で解説します)
ニューラルネットワークの出力から平均と分散を求め、 から をサンプリングし、その から入力データ の復元を行います。
そもそもの目的
ネットワーク構造は、比較的簡単にお話しできましたが、損失関数の理解は難しいです。
VAE を実装するにあたって、私たちが考えるべき損失関数とは何でしょうか。
生成モデルの目的は、データの分布、すなわち を求めることでした。
尤度(ゆうど)の最大化をすることで確率分布を推定する
確率分布を求める方法として、尤度の最大化があげられます。
確率分布 が、何らかのパラメータ (普通は平均とか分散をあらわす)で表されているとします。
最尤推定で確率分布を求める際の尤度関数は、同時確率分布に等しく、以下のようになります。
また、尤度関数は、普通対数を取ります。
尤度が確率の掛け算であるため、対数を取ることで、微分の計算が和で済むからです。
この尤度関数を最大化するようなパラメータを求めることで、最もデータの生成分布に近い分布が得られます。
損失関数
対数尤度関数の最大化が目的となったので、さらに式を変形します。
観測データは、潜在変数により生成されたと考えることもできるため、 により を周辺化し、式変形すると以下のようになります。
ここで は、潜在変数 が与えられたときの復元データの分布です。
そのため、ニューラルネットワークでいう Decoder 部は、この分布に基づいてサンプリングされたものが出力されます。
また逆に、 はデータが与えられたときの潜在変数の分布であり、Decoder 同様、ネットワークの Encoder 部分は、この分布に基づいて が観測されます。
AutoEncoder では、データ と潜在変数 の関係を点で求めていました。
しかし、VAE では、 を求めることでデータの低次元な分布を得ることができます。
そして、このを近似することで、対数尤度の最大化を解いていきます。
近似分布の平均と分散を推定する
近似した分布 は、正規分布に従うとし、先ほど述べたように、その平均と分散パラメータを Encoder により推定していきます。
で分布を仮定したら、イェンセンの不等式から対数尤度の下限を求めます。
②の下限は、変分下限(Variational Lower Bound)と言われており、VAE がそう呼ばれることの根拠となっています。
対数尤度の最大化をするには、この下限を押し上げればよいということになります。
対数尤度①と変分下限②の差を計算
ここで、対数尤度①と、右辺に出てきた変分下限②の差を計算してみましょう。
最後の式は、KL ダイバージェンスと呼ばれるもので、分布間の距離を表す指標です。
KL ダイバージェンスは、2つの分布が全く同じであるならば「0」を示します。
つまり、対数尤度は、以下の2つの項で表されることが分かりました。
第2項目は、近似精度がよくなれば、おのずと「0」に近づく非負値です。
そのため、対数尤度の最大化を行うには、変分下限の最大化を行えばよいことになります。
変分下限を式変形していくと、ようやくお目当ての損失関数が得られます。
この式を最大化することで、VAE は学習を進めていきます。
実装上の損失関数
損失関数は、このままでは積分計算が入っているため、積分が入らない形に変形します。
1項目の KL ダイバージェンスについて
まずは、1項目の KL ダイバージェンスについて説明します。
正規分布同士の KL ダイバージェンスは、以下のように計算されます。
は、潜在変数の次元数です。
積分計算をゴリゴリやれば出るので、式は割愛します。
2項目の積分はサンプリング近似を実行
2項目に関して、積分計算を解くのが難しいので、サンプリング近似を行います。
これは、確率分布 に対する期待値計算であるため、有限個のサンプルの平均で近似してしまおうというものです。
また、mnist データは、普通「0, 1」 のデータで表されているため、分布には、ベルヌーイ分布を仮定します。
さらに、学習時のバッチサイズが大きければ、 で十分です。
以上を元に、近似した値を求めていくと、以下のようになります。
ただし、は、潜在変数を入力とする Decoder の関数、つまりは、再構成後の画素値を出力します。
以上から、第2項の損失関数は、以下のように表されます。
実験結果
画像の再構成
AutoEncoder と同様に、再構成をしてみました。
元画像
再構成結果
潜在変数が二次元なので、間違えている部分もありますね。
VAE だからと言って、AutoEncoder よりも表現力が上がるわけではなさそうです。
それに画像がぼやけています。
※ VAE に限らず、ピクセル単位で誤差を計算するモデルは、すべて画像がぼけやすいです。
潜在変数の分布の可視化
潜在変数の散らばりも見てみましょう!
AutoEncoder のように無秩序ではなく、分布が正規分布に引き寄せられるようになりました。
この分布であれば、本来は、非線形であるそれぞれの数字の分布を線形分離することもできそうですね。
連続的な数字画像の生成
では今度は、この二次元に落とし込んだ潜在変数上で、画像を滑らかに変化させてみましょう!
潜在変数の値を徐々に変化させ、その値を Decoder に通して画像を生成しました。
「2」⇒「4」⇒「9」⇒「3」⇒「5」⇒「3」⇒「8」と、左上から右下にジグザグに見ていくことで、数字が徐々に変化していく様子が分かります。
さいごに
確率モデルを扱う VAE を用いることで、画像を連続的に変化をさせることが可能になりました。
今回は、数字のみで実験しましたが、VAE を用いれば、顔画像であろうと衣類であろうと様々なものを連続的に変化させ、モーフィングさせることができます。
面白いので、ぜひ、ご自身の手でも確かめてみてくださいね!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ソースコード全体
1import torch
2import torchvision
3
4import torchvision.datasets as dset
5from torch import nn
6from torch.autograd import Variable
7from torch.utils.data import DataLoader
8from torchvision import transforms
9from torchvision.utils import save_image
10from mpl_toolkits.mplot3d import axes3d
11from torchvision.datasets import MNIST
12import torch.nn.functional as F
13import os
14import pylab
15import math
16import matplotlib.pyplot as plt
17
18#gpuデバイスがあるならgpuを使う
19#ないならcpuで
20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
22#カレントディレクトリにdc_imgというフォルダが作られる
23if not os.path.exists('./dc_img'):
24 os.mkdir('./dc_img')
25
26#カレントディレクトリにdataというフォルダが作られる
27if not os.path.exists('./data'):
28 os.mkdir('./data')
29
30
31num_epochs = 5#エポック数
32batch_size = 30 #バッチサイズ
33learning_rate = 1e-3 #学習率
34
35train =True#Trueなら訓練用データ、Falseなら検証用データを使う
36pretrained =False #学習済みのモデルを使うときはここをTrueに
37latent_dim = 2 #最終的に落とし込む次元数
38save_img = True #元画像と再構成画像を保存するかどうか、バッチサイズが大きいときは保存しない方がいい
39
40def to_img(x):
41 x = x
42 x = x.clamp(0, 1)
43
44 return x
45
46#画像データを前処理する関数
47transform = transforms.Compose([
48
49 transforms.RandomResizedCrop(32, scale=(1.0, 1.0), ratio=(1., 1.)),
50 transforms.ToTensor(),
51 ])
52
53#このコードで自動で./data/以下にm-nistデータがダウンロードされる
54trainset = torchvision.datasets.MNIST(root='./data/',
55 train=True,
56 download=True,
57 transform=transform)
58
59#このコードで自動で./data/以下にm-nistデータがダウンロードされる
60testset = torchvision.datasets.MNIST(root='./data/',
61 train=False,
62 download=True,
63 transform=transform)
64#学習時なら訓練用データを用いる
65if train:
66 dataloader = DataLoader(trainset,
67
68 batch_size=batch_size, shuffle=True)
69
70#テスト時なら検証用データを用いる
71else:
72 dataloader = DataLoader(testset, batch_size=batch_size, shuffle=True)
73
74#ネットワーク定義
75class VAE(nn.Module):
76 def __init__(self, z_dim):
77 super(VAE, self).__init__()
78
79 self.conv1 = nn.Sequential(
80 nn.Conv2d(1, 64, 3, stride=1, padding=1), # b, 64, 32, 32
81 nn.BatchNorm2d(64),
82 nn.LeakyReLU(0,True),
83 nn.MaxPool2d(2) # b, 64, 16, 16
84 )
85
86 self.conv2 = nn.Sequential(
87 nn.Conv2d(64, 128, 3, stride=1, padding=1), # b, 128, 16, 16
88 nn.BatchNorm2d(128),
89 nn.LeakyReLU(0,True),
90 nn.MaxPool2d(2) # b, 128, 8, 8
91 )
92
93 self.conv3 = nn.Sequential(
94 nn.Conv2d(128, 256, 3, stride=1, padding=1), # b, 256, 8, 8
95 nn.BatchNorm2d(256),
96 nn.LeakyReLU(0,True),
97 nn.MaxPool2d(2) # b, 256, 4, 4
98 )
99
100 self.conv4 = nn.Sequential(
101 nn.Conv2d(256, 512, 4, stride=1, padding=0), # b, 512, 1, 1
102 nn.BatchNorm2d(512),
103 nn.LeakyReLU(0,True),
104 )
105
106
107 self.mean = nn.Sequential(
108 nn.Linear(512,latent_dim),# b, 512 ==> b, latent_dim
109 )
110
111 self.var = nn.Sequential(
112 nn.Linear(512,latent_dim),# b, 512 ==> b, latent_dim
113 )
114 self.decoder = nn.Sequential(
115 nn.Linear(latent_dim,512),# b, latent_dim ==> b, 512
116 nn.BatchNorm1d(512),
117 nn.LeakyReLU(0,True),
118 )
119 self.convTrans1 = nn.Sequential(
120 nn.ConvTranspose2d(512, 256, 4, stride=2,padding = 0), # b, 256, 4, 4
121 nn.BatchNorm2d(256),
122 nn.LeakyReLU(0,True),
123 )
124 self.convTrans2 = nn.Sequential(
125 nn.ConvTranspose2d(256, 128, 4, stride=2,padding = 1), # b, 128, 8, 8
126 nn.BatchNorm2d(128),
127 nn.LeakyReLU(0,True),
128 )
129 self.convTrans3 = nn.Sequential(
130 nn.ConvTranspose2d(128, 64, 4, stride=2,padding = 1), # b, 64, 16, 16
131 nn.BatchNorm2d(64),
132 nn.LeakyReLU(0,True),
133 )
134 self.convTrans4 = nn.Sequential(
135 nn.ConvTranspose2d(64, 1, 4, stride=2,padding = 1), # b, 3, 32, 32
136 nn.BatchNorm2d(1),
137 nn.Sigmoid()
138 )
139
140 #Encoderの出力に基づいてzをサンプリングする関数
141 #誤差逆伝搬ができるようにreparameterization trickを用いる
142 def _sample_z(self, mean, var):
143 std = var.mul(0.5).exp_()
144 eps = Variable(std.data.new(std.size()).normal_())
145 return eps.mul(std).add_(mean)
146
147 #Encoder
148 def _encoder(self, x):
149 x = self.conv1(x)
150 x = self.conv2(x)
151 x = self.conv3(x)
152 x = self.conv4(x)
153 x = x.view(-1,512)
154 mean = self.mean(x)
155 var = self.var(x)
156 return mean,var
157
158 #Decoder
159 def _decoder(self, z):
160 z = self.decoder(z)
161 z = z.view(-1,512,1,1)
162 x = self.convTrans1(z)
163 x = self.convTrans2(x)
164 x = self.convTrans3(x)
165 x = self.convTrans4(x)
166 return x
167
168 def forward(self, x):
169 # xは元画像
170 mean,var = self._encoder(x) #Decoderの出力はlog σ^2を想定
171 z = self._sample_z(mean, var) #潜在変数の分布に基づいてzをサンプリング
172 x = self._decoder(z) #サンプリングしたzに対して画像を再構成
173 return x,mean,var,z
174
175 def loss(self, x):
176 mean, var = self._encoder(x) #Decoderの出力はlog σ^2を想定
177
178 KL = -0.5 * torch.mean(torch.sum(1 + var- mean**2 - var.exp())) #KLダイバージェンス
179
180 z = self._sample_z(mean, var) #潜在変数の分布に基づいてzをサンプリング
181 y = self._decoder(z) #サンプリングしたzに対して画像を再構成
182 delta = 1e-7 #logの中身がマイナスにならないように微小な値を与える
183 reconstruction = torch.mean(torch.sum(x * torch.log(y+delta) + (1 - x) * torch.log(1 - y +delta))) #再構成誤差
184 lower_bound = [-KL, reconstruction]
185 return -sum(lower_bound),y,mean,var,z
186
187def main():
188 #ネットワーク宣言
189 model = VAE(latent_dim).to(device)
190
191 #事前に学習したモデルがあるならそれを使う
192 if pretrained:
193 param = torch.load('./conv_Variational_autoencoder_{}dim.pth'.format(latent_dim))
194 model.load_state_dict(param)
195
196
197 #最適化法はAdamを選択
198 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
199
200 weight_decay=1e-5)
201
202 for epoch in range(num_epochs):
203 itr = 0
204 print(epoch)
205 for data in dataloader:
206 itr+=1
207 img, num = data
208 #img --> [batch_size,1,32,32]
209 #num --> [batch_size,1]
210 #imgは画像本体
211 #numは画像に対する正解ラベル
212 #ただし、学習時にnumは使わない
213
214 #imgをデバイスに乗っける
215 img = Variable(img).to(device)
216 # ===================forward=====================
217
218 #outputが再構成画像、latentは次元削減されたデータ
219 if train == False:
220 output,mu,var,latent = model(img)
221
222 #学習時であれば、ネットワークパラメータを更新
223 if train:
224 #lossを計算
225 #元画像と再構成後の画像が近づくように学習
226 loss,output,mu,var,latent = model.loss(img)
227 # ===================backward====================
228 #勾配を初期化
229 optimizer.zero_grad()
230 #微分値を求める
231 loss.backward()
232 #パラメータの更新
233 optimizer.step()
234 print('{} {}'.format(itr,loss))
235 # ===================log========================
236
237 #データをtorchからnumpyに変換
238 z = latent.cpu().detach().numpy()
239 num = num.cpu().detach().numpy()
240
241 #次元数が3の時のプロット
242 if latent_dim == 3:
243 fig = plt.figure(figsize=(15, 15))
244 ax = fig.add_subplot(111, projection='3d')
245 ax.scatter(z[:, 0], z[:, 1], z[:, 2], marker='.', c=num, cmap=pylab.cm.jet)
246 for angle in range(0,360,60):
247 ax.view_init(30,angle)
248 plt.savefig("./fig{}.png".format(angle))
249
250 #次元数が2の時のプロット
251 if latent_dim == 2:
252 plt.figure(figsize=(15, 15))
253 plt.scatter(z[:, 0], z[:, 1], marker='.', c=num, cmap=pylab.cm.jet)
254 plt.colorbar()
255 plt.grid()
256 plt.savefig("./fig.png")
257
258 #元画像と再構成後の画像を保存するなら
259 if save_img:
260 value = int(math.sqrt(batch_size))
261 pic = to_img(img.cpu().data)
262 pic = torchvision.utils.make_grid(pic,nrow = value)
263 save_image(pic, './dc_img/real_image_{}.png'.format(epoch)) #元画像の保存
264
265 pic = to_img(output.cpu().data)
266 pic = torchvision.utils.make_grid(pic,nrow = value)
267 save_image(pic, './dc_img/image_{}.png'.format(epoch)) #再構成後の画像の保存
268
269 #もし学習時ならモデルを保存
270 #バージョン管理は各々で
271 if train == True:
272 torch.save(model.state_dict(), './conv_Variational_autoencoder_{}dim.pth'.format(latent_dim))
273
274if __name__ == '__main__':
275 main()
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit
おすすめ記事
immichを知ってほしい
2024.10.31