「AutoEncoder」から見る機械学習の次元削減の意味
IT技術
AutoEncoder から見る機械学習の次元削減の意味とは
「オッカムの剃刀」という言葉をご存知ですか?
「オッカムの剃刀」は、何か現象を説明する際、仮定は少ない方が無駄がなく分かりやすいというものです。
これは、14世紀のスコラ哲学者オッカムの有名な考え方です。
「オッカムの剃刀」とは?
例えば、皆さんの周りに、会話をするときに「え、その話いる?」といった話し方をする人はいませんか?
実は、先ほどの「14世紀のスコラ哲学者オッカムの有名な考え方です。」という言葉の中にも、剃刀で削ぎ落とすべき無駄な単語が含まれています。
「スコラ哲学」という単語にばかり目がいってしまい、強調したいはずの「オッカム」が霞んでしまいます。
このように、無駄のない説明の方が優れているという考え方を「オッカムの剃刀」といいます。
物理学でも用いられる「オッカムの剃刀」
物理学にも「オッカムの剃刀」は多く適用されており、物理現象は少ない変数で表されることがほとんどです。
中でも、有名なニュートンの運動方程式は驚くほど単純です。
あれだけ単純なものにも、オッカムの剃刀という考え方が採用されていたのです。
「オッカムの剃刀」から「次元削減」へ
解析分野でも「オッカムの剃刀」は採用されています。
解析分野では、オッカムの剃刀の考え方の元、データの次元を減らす「次元削減」というものをよく行います。
スイスロール
以下は「スイスロール」と呼ばれる三次元データです。
実はこのスイスロール、三次元とはいいながらも、データの分布は偏っており、二次元に引き延ばしてしまうことができます。
三次元の時は見る角度によって、データの分布が重なっていたところがありました。
「次元削減」により、二次元では、データ同士の分布を切り離すことができていることが分かります。
そのため、新たなデータが入ってきた際に、「そのデータはどの分布に含まれるか?」といった判定がとても楽になります。
これが「次元圧縮」という手法で、解析に必要のない余分な要素を省くことができます。
「次元削減」と「機械学習」の関係
さて、次元削減について説明していきましたが、いよいよ「機械学習」との関係を見ていきたいと思います!
最も単純な削減法「主成分分析」
次元削減にも色々手法があり、最も単純な削減法は「主成分分析」です。
しかし主成分分析では、スイスロールのような非線形データ(並進、回転、拡大縮小以外の変形)に対してはうまく動作しません。
「ニューラルネットワーク」の根幹は次元削減
一方、「ニューラルネット」では、3層以上のパーセプトロンを持つならば、任意の非線形問題を解くことができます。
そのため、多くのニューラルネットワークでは、中間層のユニット数が入力に比べて小さくなっている場合がほとんどであり、入力層から中間層への次元削減をしていると言えます。
すなわち、ニューラルネットワークの根幹は次元削減であるということです。
「画像」と「次元削減」
では、画像空間における次元圧縮を学んでいきましょう!
画像空間というのは、縦×横の画像サイズだけの次元数(画素値)を持つため、一般的にかなり高次元です。
しかし、画像空間における分布というのは、かなり低次元な分布に偏っていると言われています。(多様体仮説)
そのため、次元削減することにより、画像の低次元な分布を得ることが期待できます。
そして、その画像空間の多くは非線形に分布しています。(本当はもっと高次元です。スイスロールもその一例です。)
「非線形」な空間の「次元削減」ときたら、「ニューラルネットワーク」でしょう。
そのため、画像空間の次元削減には、多くの場合ニューラルネットワークが用いられます。
「AutoEncoder」の構造
では、どのようなニューラルネットワークを作成すれば良いかを考えていきましょう!
まずは、内容をシンプルにするために、CNN でない「普通のニューラルネットワーク」の場合を考えていきます!
入力は縦×横の画素数だけのユニット数を持ちます。
次元削減をしたいので、ユニット数を中間層では入力層より減らしていきます。
この中間層のユニット数が、最終的に次元圧縮した次元数となります。
画像の「多様体仮説」を思い出す
ですが、これだけではニューラルネットワークは「学習」が出来ません。
ここで、画像の「多様体仮説」を思い出します。
画像の分布は、低次元に分布しているということでした。
低次元な分布からでも高次元な画像を生成することは可能なはずです!
そのため、出力層は元の入力層と同じユニット数にします。
ネットワーク構造
よって、以下のようなネットワーク構造になります。
そして、学習は入力と出力が同じになるように、ニューラルネットワークのパラメータを最適化していきます。
誤差関数は、二乗誤差で問題ないでしょう。
これが俗に言う、「自己符号化器」または「AutoEncoder」です。
実験に用いたネットワーク構造
今回使ったモデルは、pytorch 公式のサンプルを改良し、「BatchNormlization」や「LeakyLeRU」を加えた CNN 構造を取ります。
1#ネットワーク定義
2class autoencoder(nn.Module):
3
4 def __init__(self):
5
6 super(autoencoder, self).__init__()
7
8 self.encoder = nn.Sequential(
9
10 nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10
11
12 nn.LeakyReLU(0,True),
13
14 nn.BatchNorm2d(16),
15
16 nn.MaxPool2d(2, stride=2), # b, 16, 5, 5
17
18 nn.Conv2d(16, 32, 3, stride=2, padding=1), # b, 32, 3, 3
19
20 nn.LeakyReLU(0,True),
21
22 nn.BatchNorm2d(32),
23
24 nn.MaxPool2d(2, stride=1), # b, 32, 2, 2
25
26
27
28 nn.Conv2d(32, 64, 2, stride=1, padding=0), # b, 64, 1, 1
29
30 nn.LeakyReLU(0,True),
31
32 nn.BatchNorm2d(64),
33
34 )
35
36 self.fc1 = nn.Sequential(
37
38
39 nn.Linear(64,latent_dim),
40
41 nn.Tanh() #潜在変数を[-1,1]にするためにハイパボリックタンジェントを活性化関数に
42 )
43
44 self.fc2 = nn.Sequential(
45 nn.Linear(latent_dim,64),
46
47 nn.LeakyReLU(0,True)
48
49
50 )
51
52 self.decoder = nn.Sequential(
53
54 nn.ConvTranspose2d(64, 32, 2, stride=1), # b, 32, 2, 2
55
56 nn.LeakyReLU(0,True),
57
58 nn.BatchNorm2d(32),
59
60 nn.ConvTranspose2d(32, 16, 3, stride=2), # b, 16, 5, 5
61
62 nn.LeakyReLU(0,True),
63
64 nn.BatchNorm2d(16),
65
66 nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b, 8, 15, 15
67
68 nn.LeakyReLU(0,True),
69
70 nn.BatchNorm2d(8),
71
72 nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b, 1, 28, 28
73
74 nn.Tanh()
75
76 )
77
78
79
80 def forward(self, x):
81
82 # ===================encoder=====================
83 # xは元画像
84
85 latent = self.encoder(x) #CNNから得られた特徴量
86
87 latent = latent.view(-1,64) #特徴量をベクトル化
88
89 latent = self.fc1(latent) #CNNから得られた特徴量をさらに低次元の特徴に写像
90
91 pre_latent = latent #次元削減により得られた潜在変数
92
93 # ===================decoder=====================
94
95 latent = self.fc2(latent)
96
97 latent = latent.view(-1,64,1,1)
98
99 x = self.decoder(latent) #再構成画像
100
101 return x,pre_latent
実験結果
二次元まで圧縮した例
以下は、実際に数字画像を二次元まで次元削減した時の、再構成画像の例です。
元の画像
再構成した画像
そもそも生成できていない数字が存在したり間違えているものも多数あります。
そのため、これらの数字の分布を確認してみます。
数字の分布
「9」と「4」、「7」の分布が完全にかぶっています。
特に「4」は、そもそも生成できていないことが分かります。
他にも「3」と「5」と「8」も同様で、うまくこれらの分布を離すことができず、誤生成につながったと考えることができます。
つまり、今回のネットワークモデルでは、m-nist を二次元で埋め込むことに失敗しました。
そこで三次元でも、これらの分布を可視化してみます。
三次元まで圧縮した例
まずは、再構成前後の画像例をご覧ください。
元の
二次元の時と比べて、表現できる数字の数が増えたのが、ぱっと見で分かるかと思います。
ただ、「9」と「7」の組み合わせや、「4」と「9」等に間違いが多いことが分かります。
実際に、これらの数字の分布が近いかどうかを見てみます。
数字の分布
確かにこれらの分布は近く、特に「9」関連で間違いが多そうな気がします。
また、見方によっては分離できていないように見えていても、違う角度から見ると分離できている「3」と「5」のような組み合わせもあり、大変興味深いです。
三次元にまで削減すれば、各数字のある程度の分布を分けることができました。
さいごに
今回は、手書き文字を用いた「AutoEncoder」を利用して、「次元圧縮」について解説してみました!
削減した次元が何を表しているのかは、自分で解釈する必要があります。
「曲がり具合」を表しているのか、それとも「線の太さ」を表しているのか?
これは、「主成分分析」であっても同じです。
これがデータサイエンスの難しい所です。
次元は減らし過ぎもよくない
また、次元数は減らし過ぎてもいけません。
こればかりは、試してみて一番良いユニット数を自分たちで見つけるしかありません。
ここがニューラルネットワークの弱点でもあります。
次回は、オートエンコーダと深い関係のある、「VAE」という手法をお話ししていきます。
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
全体のソースコード
1import torch
2
3import torchvision
4
5
6import torchvision.datasets as dset
7from torch import nn
8
9from torch.autograd import Variable
10
11from torch.utils.data import DataLoader
12
13from torchvision import transforms
14
15from torchvision.utils import save_image
16
17from mpl_toolkits.mplot3d import axes3d
18
19from torchvision.datasets import MNIST
20
21import os
22
23import pylab
24
25import matplotlib.pyplot as plt
26
27
28#カレントディレクトリにdc_imgというフォルダが作られる
29
30if not os.path.exists('./dc_img'):
31
32 os.mkdir('./dc_img')
33
34#カレントディレクトリにdataというフォルダが作られる
35
36if not os.path.exists('./data'):
37
38 os.mkdir('./data')
39
40
41
42num_epochs = 100 #エポック数
43
44batch_size = 100 #バッチサイズ
45
46learning_rate = 1e-2 #学習率
47
48
49train = True #Trueなら訓練用データ、Falseなら検証用データを使う
50
51pretrained = False #学習済みのモデルを使うときはここをTrueに
52
53latent_dim = 3 #最終的に落とし込む次元数
54
55save_img = False #元画像と再構成画像を保存するかどうか、バッチサイズが大きいときは保存しない方がいい
56
57
58
59def to_img(x):
60
61 x = 0.5 * (x + 1)
62
63 x = x.clamp(0, 1)
64
65 x = x.view(x.size(0), 1, 28, 28)
66
67 return x
68
69#画像データを前処理する関数
70transform = transforms.Compose(
71 [transforms.ToTensor(),
72 transforms.Normalize((0.5, ), (0.5, ))])
73
74#このコードで自動で./data/以下にm-nistデータがダウンロードされる
75trainset = torchvision.datasets.MNIST(root='./data/',
76 train=True,
77 download=True,
78 transform=transform)
79
80#このコードで自動で./data/以下にm-nistデータがダウンロードされる
81testset = torchvision.datasets.MNIST(root='./data/',
82 train=False,
83 download=True,
84 transform=transform)
85
86#学習時なら訓練用データを用いる
87if train:
88 dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
89
90#テスト時なら検証用データを用いる
91else:
92 dataloader = DataLoader(testset, batch_size=batch_size, shuffle=True)
93
94
95
96#ネットワーク定義
97class autoencoder(nn.Module):
98
99 def __init__(self):
100
101 super(autoencoder, self).__init__()
102
103 self.encoder = nn.Sequential(
104
105 nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10
106
107 nn.LeakyReLU(0,True),
108
109 nn.BatchNorm2d(16),
110
111 nn.MaxPool2d(2, stride=2), # b, 16, 5, 5
112
113 nn.Conv2d(16, 32, 3, stride=2, padding=1), # b, 32, 3, 3
114
115 nn.LeakyReLU(0,True),
116
117 nn.BatchNorm2d(32),
118
119 nn.MaxPool2d(2, stride=1), # b, 32, 2, 2
120
121
122
123 nn.Conv2d(32, 64, 2, stride=1, padding=0), # b, 64, 1, 1
124
125 nn.LeakyReLU(0,True),
126
127 nn.BatchNorm2d(64),
128
129 )
130
131 self.fc1 = nn.Sequential(
132
133
134 nn.Linear(64,latent_dim),
135
136 nn.Tanh() #潜在変数を[-1,1]にするためにハイパボリックタンジェントを活性化関数に
137 )
138
139 self.fc2 = nn.Sequential(
140 nn.Linear(latent_dim,64),
141
142 nn.LeakyReLU(0,True)
143
144
145 )
146
147 self.decoder = nn.Sequential(
148
149 nn.ConvTranspose2d(64, 32, 2, stride=1), # b, 32, 2, 2
150
151 nn.LeakyReLU(0,True),
152
153 nn.BatchNorm2d(32),
154
155 nn.ConvTranspose2d(32, 16, 3, stride=2), # b, 16, 5, 5
156
157 nn.LeakyReLU(0,True),
158
159 nn.BatchNorm2d(16),
160
161 nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b, 8, 15, 15
162
163 nn.LeakyReLU(0,True),
164
165 nn.BatchNorm2d(8),
166
167 nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b, 1, 28, 28
168
169 nn.Tanh()
170
171 )
172
173
174
175 def forward(self, x):
176
177 # ===================encoder=====================
178 # xは元画像
179
180 latent = self.encoder(x) #CNNから得られた特徴量
181
182 latent = latent.view(-1,64) #特徴量をベクトル化
183
184 latent = self.fc1(latent) #CNNから得られた特徴量をさらに低次元の特徴に写像
185
186 pre_latent = latent #次元削減により得られた潜在変数
187
188 # ===================decoder=====================
189
190 latent = self.fc2(latent)
191
192 latent = latent.view(-1,64,1,1)
193
194 x = self.decoder(latent) #再構成画像
195
196 return x,pre_latent
197
198
199def main():
200 #gpuデバイスがあるならgpuを使う
201 #ないならcpuで
202 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
203
204 #ネットワーク宣言
205 model = autoencoder().to(device)
206
207 #事前に学習したモデルがあるならそれを使う
208 if pretrained:
209 param = torch.load('./conv_autoencoder_{}dim.pth'.format(latent_dim))
210 model.load_state_dict(param)
211
212 #誤差関数は二乗誤差で
213 criterion = nn.MSELoss()
214
215 #最適化法はAdamを選択
216 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
217
218 weight_decay=1e-5)
219
220
221
222 for epoch in range(num_epochs):
223
224 print(epoch)
225
226 for data in dataloader:
227
228 img, num = data
229 #img --> [batch_size,1,28,28]
230 #num --> [batch_size,1]
231 #imgは画像本体
232 #numは画像に対する正解ラベル
233 #ただし、学習時にnumは使わない
234
235 #imgをデバイスに乗っける
236 img = Variable(img).to(device)
237
238 # ===================forward=====================
239
240 #outputが再構成画像、latentは次元削減されたデータ
241 output,latent = model(img)
242
243 #学習時であれば、ネットワークパラメータを更新
244 if train:
245 #lossを計算
246 #元画像と再構成後の画像が近づくように学習
247 loss = criterion(output, img)
248
249 # ===================backward====================
250 #勾配を初期化
251 optimizer.zero_grad()
252
253 #微分値を求める
254 loss.backward()
255
256 #パラメータの更新
257 optimizer.step()
258
259 print('{}'.format(loss))
260
261
262
263
264 # ===================log========================
265
266 #データをtorchからnumpyに変換
267 z = latent.cpu().detach().numpy()
268 num = num.cpu().detach().numpy()
269
270 #次元数が3の時のプロット
271 if latent_dim == 3:
272 fig = plt.figure(figsize=(15, 15))
273 ax = fig.add_subplot(111, projection='3d')
274 ax.scatter(z[:, 0], z[:, 1], z[:, 2], marker='.', c=num, cmap=pylab.cm.jet)
275 for angle in range(0,360,60):
276 ax.view_init(30,angle)
277 plt.savefig("./fig{}.png".format(angle))
278
279 #次元数が2の時のプロット
280 if latent_dim == 2:
281 plt.figure(figsize=(15, 15))
282 plt.scatter(z[:, 0], z[:, 1], marker='.', c=num, cmap=pylab.cm.jet)
283 plt.colorbar()
284 plt.grid()
285 plt.savefig("./fig.png")
286
287 #元画像と再構成後の画像を保存するなら
288 if save_img:
289 pic = to_img(img.cpu().data)
290
291 save_image(pic, './dc_img/real_image_{}.png'.format(epoch)) #元画像の保存
292
293 pic = to_img(output.cpu().data)
294
295 save_image(pic, './dc_img/image_{}.png'.format(epoch)) #再構成後の画像の保存
296
297 #もし学習時ならモデルを保存
298 #バージョン管理は各々で
299 if train == True:
300 torch.save(model.state_dict(), './conv_autoencoder_{}dim.pth'.format(latent_dim))
301
302if __name__ == '__main__':
303 main()
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit
おすすめ記事
immichを知ってほしい
2024.10.31