【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
IT技術
【後編】PyTorchでCIFAR-10をCNNに学習させる
【前編】の続きとなります。
引き続き、PyTorch(パイトーチ)で畳み込みニューラルネットワーク(CNN)を実装していきたいと思います。
今回は、学習結果からとなります!
前編の記事はこちら
学習結果
学習が終わりましたが、やはり「MNIST」と違って学習に時間がかかりますね!
ですが、50エポックなので数十分で終わるかと思います。(マシンスペックに依存しますが...)
訓練ロスと訓練 / テスト精度
学習によって得られた、『訓練ロス』『訓練 / テスト精度』から見てみましょう。
まだ、精度は「70%程度」と低いですが、しっかり学習できていそうですね!
畳み込み層のフィルタ
ちなみに、学習後の畳み込み層のフィルタを見てみると
学習前と比べて、何かしらフィルタに模様が見えてきましたね。
よく見ると、斜め方向に対応するフィルタや、横方向に対応するフィルタが見受けられますが、まだはっきりとは分かりませんね。
またグラフを見ると、学習回数を増やせば、まだ精度は伸びそうな雰囲気があります。
「もっともっと学習を増やしてみましょう!」
…と言いたいところですが、そうなると学習に膨大な時間がかかってしまいそうです。
GPUを使ってみる
「CUDA(Compute Unified Device Architecture:クーダ)」が使用可能であれば、PyTorchでは、簡単にGPUに演算を行わせることができます。
メイン処理部の冒頭に、以下を加筆して下さい。
1if __name__ == '__main__':
2 epoch = 300 # 今回は300エポック!
3
4 loader = load_cifar10()
5 classes = ('plane', 'car', 'bird', 'cat', 'deer',
6 'dog', 'frog', 'horse', 'ship', 'truck')
7
8 net: MyCNN = MyCNN()
9 criterion = torch.nn.CrossEntropyLoss() # ロスの計算
10 optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
11
12 """ 以下を加筆 """
13 # もしGPUが使えるなら使う
14 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15 net.to(device)
16 print(device)
torch.cuda.is_available() は、「CUDA」が使用可能ならTrueを返すといった、GPU を使用できるかを簡単に確認できる関数です。
その後、ネットワークを to() でデバイスに投げるだけです。
しかし、これだけではエラーを吐かれてしまいます。
GPU に使用するデータセットを投げる
使用するデータセットも、GPU に投げる必要があるので、各データローダーのループに、以下のように加筆して下さい。
1# ...
2for i, (images, labels) in enumerate(loader['train']):
3 images = images.to(device) # to GPU?
4 labels = labels.to(device)
5 # ...
matplotlib で重みを描画するには
最後に、重みを描画する「matplotlib」では、CPU で描画するため、その際は逆に CPU に投げる必要があります。
以下のように書けば OK です!
1# ...
2plt.imshow(weight.data.to('cpu').numpy(), cmap='winter') # to('cpu')を追加!
3# ...
これで、GPU を使用する準備が整いました!
早速学習させてみましょう!
ちなみに筆者の環境は、GPU は「NVIDIA GeForce GTX Titan Blak 6GB」で、「CUDA」はバーション9.2で動作確認をしています。
学習結果 (300エポック)
さすがGPUを使うと、目に見えて学習が早いです!
時間はしっかりと測定していませんが、1.5 ~ 2倍くらい早いです。
では早速、300エポックの学習結果を見てみましょう!
なんと、訓練精度は「約100%」になりましたが、テスト精度は「60%程度」になってしまいました。
「過学習」が起こってしまいました!
【過学習とは?】
過学習とは、このように訓練データに適応しすぎて、テストデータなどに対する性能、いわゆる汎化性能が低下してしまう現象を言います。
Dropoutを導入してみる
それでは、過学習対策としてメジャーな「Dropout」を導入してみましょう。
「Dropout」とは、簡単に言えば、学習時に一部のニューロンを、わざと非活性化させ、訓練データに適合しすぎないようにする手法です。
1class MyCNN(torch.nn.Module):
2 def __init__(self):
3 super(MyCNN, self).__init__()
4 self.conv1 = torch.nn.Conv2d(3, # チャネル入力
5 6, # チャンネル出力
6 5, # カーネルサイズ
7 1, # ストライド (デフォルトは1)
8 0, # パディング (デフォルトは0)
9 )
10 self.conv2 = torch.nn.Conv2d(6, 16, 5)
11
12 self.pool = torch.nn.MaxPool2d(2, 2) # カーネルサイズ, ストライド
13
14 self.dropout1 = torch.nn.Dropout2d(p=0.3) # [new] Dropoutを追加してみる
15
16 self.fc1 = torch.nn.Linear(16 * 5 * 5, 120) # 入力サイズ, 出力サイズ
17 self.dropout2 = torch.nn.Dropout(p=0.5) # [new] Dropoutを追加してみる
18 self.fc2 = torch.nn.Linear(120, 84)
19 self.fc3 = torch.nn.Linear(84, 10)
20
21 def forward(self, x):
22 x = f.relu(self.conv1(x))
23 x = self.pool(x)
24 x = f.relu(self.conv2(x))
25 x = self.pool(x)
26 x = self.dropout1(x) # [new] Dropoutを追加
27 x = x.view(-1, 16 * 5 * 5) # 1次元データに変えて全結合層へ
28 x = f.relu(self.fc1(x))
29 x = self.dropout2(x) # [new] Dropoutを追加
30 x = f.relu(self.fc2(x))
31 x = self.fc3(x)
32
33 return x
追加してみました!
さて、結果は、どうなるでしょうか!?
実験結果 (300エポック + Dropout)
過学習はなくなりましたね!
ただ、やはり精度は「70%程度」といったところでしょうか。
もしかしたら、これが「LeNet」の限界なのかもしれません...。
他にも「Batch Normalization(バッチ正規化)」や、「Data Augmentation(データ拡張)」などの手法を用いれば、過学習を抑制しつつ精度向上が見込めるかもしれません。
ですが、本記事ではここまでとします。
ここまできたら、もう少し深い層の、畳み込みニューラルネットワーク (DCNN: Deep Convolutional Neural Networks)を構築した方が良いでしょう。
フィルタの重みの学習結果
ちなみに、フィルタの重みの学習結果も載せておきますが、50エポックの時と、大した差はありませんね。
やや、模様が明確になったような気もします(笑)
さいごに
長丁場になりましたが、今回は、PyTorchで畳み込みニューラルネットワークを構築し、カラー画像の「CIFAR10」を学習させてみました。
また、ネットワークの内部を観察したり、いろいろな考察もしてみたので、機械学習初学者の皆さまの参考になれば幸いです。
次回は、もう少し複雑なネットワークで試してみようと考えているのでお楽しみに!
ソースコード
1import torch
2import torch.nn.functional as f
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import matplotlib.pyplot as plt
6from tqdm import tqdm
7
8
9class MyCNN(torch.nn.Module):
10 def __init__(self):
11 super(MyCNN, self).__init__()
12 self.conv1 = torch.nn.Conv2d(3, # チャネル入力
13 6, # チャンネル出力
14 5, # カーネルサイズ
15 1, # ストライド (デフォルトは1)
16 0, # パディング (デフォルトは0)
17 )
18 self.conv2 = torch.nn.Conv2d(6, 16, 5)
19
20 self.pool = torch.nn.MaxPool2d(2, 2) # カーネルサイズ, ストライド
21
22 self.dropout1 = torch.nn.Dropout2d(p=0.3) # [new] Dropoutを追加してみる
23
24 self.fc1 = torch.nn.Linear(16 * 5 * 5, 120) # 入力サイズ, 出力サイズ
25 self.dropout2 = torch.nn.Dropout(p=0.5) # [new] Dropoutを追加してみる
26 self.fc2 = torch.nn.Linear(120, 84)
27 self.fc3 = torch.nn.Linear(84, 10)
28
29 def forward(self, x):
30 x = f.relu(self.conv1(x))
31 x = self.pool(x)
32 x = f.relu(self.conv2(x))
33 x = self.pool(x)
34 x = self.dropout1(x) # [new] Dropoutを追加
35 x = x.view(-1, 16 * 5 * 5) # 1次元データに変えて全結合層へ
36 x = f.relu(self.fc1(x))
37 x = self.dropout2(x) # [new] Dropoutを追加
38 x = f.relu(self.fc2(x))
39 x = self.fc3(x)
40
41 return x
42
43 def plot_conv1(self, prefix_num=0):
44 weights1 = self.conv1.weight
45 weights1 = weights1.reshape(3*6, 5, 5)
46
47 for i, weight in enumerate(weights1):
48 plt.subplot(3, 6, i + 1)
49 plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')
50 plt.tick_params(labelbottom=False,
51 labelleft=False,
52 labelright=False,
53 labeltop=False,
54 bottom=False,
55 left=False,
56 right=False,
57 top=False)
58
59 plt.savefig('img/{}_conv1.png'.format(prefix_num))
60 plt.close()
61
62 def plot_conv2(self, prefix_num=0):
63 weights2 = self.conv2.weight
64 weights2 = weights2.reshape(6*16, 5, 5)
65
66 for i, weight in enumerate(weights2):
67 plt.subplot(6, 16, i + 1)
68 plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')
69 plt.tick_params(labelbottom=False,
70 labelleft=False,
71 labelright=False,
72 labeltop=False,
73 bottom=False,
74 left=False,
75 right=False,
76 top=False)
77
78 plt.savefig('img/{}_conv2.png'.format(prefix_num))
79 plt.close()
80
81
82def load_cifar10(batch=128):
83 train_loader = DataLoader(
84 datasets.CIFAR10('./data',
85 train=True,
86 download=True,
87 transform=transforms.Compose([
88 transforms.ToTensor(),
89 transforms.Normalize(
90 [0.5, 0.5, 0.5], # RGB 平均
91 [0.5, 0.5, 0.5] # RGB 標準偏差
92 )
93 ])),
94 batch_size=batch,
95 shuffle=True
96 )
97
98 test_loader = DataLoader(
99 datasets.CIFAR10('./data',
100 train=False,
101 download=True,
102 transform=transforms.Compose([
103 transforms.ToTensor(),
104 transforms.Normalize(
105 [0.5, 0.5, 0.5], # RGB 平均
106 [0.5, 0.5, 0.5] # RGB 標準偏差
107 )
108 ])),
109 batch_size=batch,
110 shuffle=True
111 )
112
113 return {'train': train_loader, 'test': test_loader}
114
115
116if __name__ == '__main__':
117 epoch = 300
118
119 loader = load_cifar10()
120 classes = ('plane', 'car', 'bird', 'cat', 'deer',
121 'dog', 'frog', 'horse', 'ship', 'truck')
122
123 net: MyCNN = MyCNN()
124 criterion = torch.nn.CrossEntropyLoss() # ロスの計算
125 optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
126
127 # もしGPUが使えるなら使う
128 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
129 net.to(device)
130 print(device)
131
132 # 学習前のフィルタの可視化
133 net.plot_conv1()
134 net.plot_conv2()
135
136 history = {
137 'train_loss': [],
138 'train_acc': [],
139 'test_acc': []
140 }
141
142 for e in range(epoch):
143 net.train()
144 loss = None
145 for i, (images, labels) in enumerate(loader['train']):
146 images = images.to(device) # to GPU?
147 labels = labels.to(device)
148
149 optimizer.zero_grad()
150 output = net(images)
151 loss = criterion(output, labels)
152 loss.backward()
153 optimizer.step()
154
155 if i % 10 == 0:
156 print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1,
157 (i + 1) * 128,
158 loss.item())
159 )
160
161 # 学習過程でのフィルタの可視化
162 # net.plot_conv1(e+1)
163 # net.plot_conv2(e+1)
164
165 history['train_loss'].append(loss.item())
166
167 net.eval()
168 correct = 0
169 with torch.no_grad():
170 for i, (images, labels) in enumerate(tqdm(loader['train'])):
171 images = images.to(device) # to GPU?
172 labels = labels.to(device)
173
174 outputs = net(images)
175 _, predicted = torch.max(outputs.data, 1)
176 correct += (predicted == labels).sum().item()
177
178 acc = float(correct / 50000)
179 history['train_acc'].append(acc)
180
181 correct = 0
182 with torch.no_grad():
183 for i, (images, labels) in enumerate(tqdm(loader['test'])):
184 images = images.to(device) # to GPU?
185 labels = labels.to(device)
186
187 outputs = net(images)
188 _, predicted = torch.max(outputs.data, 1)
189 correct += (predicted == labels).sum().item()
190
191 acc = float(correct / 10000)
192 history['test_acc'].append(acc)
193
194 # 学習前のフィルタの可視化
195 net.plot_conv1(300)
196 net.plot_conv2(300)
197
198 # 結果をプロット
199 plt.plot(range(1, epoch+1), history['train_loss'])
200 plt.title('Training Loss [CIFAR10]')
201 plt.xlabel('epoch')
202 plt.ylabel('loss')
203 plt.savefig('img/cifar10_loss.png')
204 plt.close()
205
206 plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc')
207 plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc')
208 plt.title('Accuracies [CIFAR10]')
209 plt.xlabel('epoch')
210 plt.ylabel('accuracy')
211 plt.legend()
212 plt.savefig('img/cifar10_acc.png')
213 plt.close()
前編の記事はこちら
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
関連記事
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit