【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
IT技術
PyTorchでCIFAR-10をCNNに学習させる
前回の『【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる』に引き続き、PyTorchで機械学習を学んでいきましょう!
今回は、PyTorchで畳み込みニューラルネットワーク(CNN)を実装していきます。
ちなみに、公式ドキュメントにも同じような実装が紹介されているようです。
ですが、本記事では、日本語で分かりやすく詳細に解説していきたいと思っています。
さらに最後には、ネットワークの内部を可視化してみたり、GPUを使用してみたりと、様々な実験が含まれている記事になっています!
ですので、ぜひ最後まで読んでみてください!
インポートされているモジュール
これから実装するコードは、以下のモジュールがあらかじめインポートされています。
1import torch
2import torch.nn.functional as f
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import matplotlib.pyplot as plt
6from tqdm import tqdm
それでは、実際に実装していきましょう!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
CIFAR10の準備
今回用いるデータセットの「CIFAR10 (サイファー10)」 は、32×32 のカラー画像からなるデータセットで、その名の通り10クラスあります。
「MNIST」は 28×28 のグレースケール画像なので、「CIFER10」の方が情報量は数倍多く、学習は難しいです。
畳み込みニューラルネットワークを学んだり、研究しているほとんどの人は、このデータセットをカラー画像の簡易なベンチマークとして使用しています。
それでは、このデータセットを読み込む関数を作っていきます。
CIFER10を読み込む関数を作る
1def load_cifar10(batch=128):
2 train_loader = DataLoader(
3 datasets.CIFAR10('./data',
4 train=True,
5 download=True,
6 transform=transforms.Compose([
7 transforms.ToTensor(),
8 transforms.Normalize(
9 [0.5, 0.5, 0.5], # RGB 平均
10 [0.5, 0.5, 0.5] # RGB 標準偏差
11 )
12 ])),
13 batch_size=batch,
14 shuffle=True
15 )
16
17 test_loader = DataLoader(
18 datasets.CIFAR10('./data',
19 train=False,
20 download=True,
21 transform=transforms.Compose([
22 transforms.ToTensor(),
23 transforms.Normalize(
24 [0.5, 0.5, 0.5], # RGB 平均
25 [0.5, 0.5, 0.5] # RGB 標準偏差
26 )
27 ])),
28 batch_size=batch,
29 shuffle=True
30 )
31
32 return {'train': train_loader, 'test': test_loader}
「MNIST」で行った時とほとんど変わりませんが、今回データの正規化として、各カラーチャネルの平均と標準偏差を「0.5」になるようにしています。
画像データセットでよくある「正規化」ですので、覚えておきましょう!
動作確認
それでは、試しに動作確認をしてみましょう!
1if __name__ == '__main__':
2 loader = load_cifar10()
3 classes = ('plane', 'car', 'bird', 'cat', 'deer',
4 'dog', 'frog', 'horse', 'ship', 'truck') # CIFAR10のクラス
5
6 for images, labels in loader['train']:
7 print(images.shape) # torch.Size([128, 3, 32, 32])
8
9 # 試しに50枚を 5x10 で見てみる
10 for i in range(5):
11 for j in range(10):
12 image = images[i*10+j] / 2 + 0.5
13 image = image.numpy()
14 plt.subplot(5, 10, i*10+j + 1)
15 plt.imshow(np.transpose(image, (1, 2, 0))) # matplotlibではチャネルは第3次元
16
17 # 対応するラベル
18 plt.title(classes[int(labels[i*10+j])])
19
20 # 軸目盛や値はいらないので消す
21 plt.tick_params(labelbottom=False,
22 labelleft=False,
23 labelright=False,
24 labeltop=False,
25 bottom=False,
26 left=False,
27 right=False,
28 top=False)
29
30 plt.show()
31 break
画像の描画はイテレーションループを使ってますが、もちろん iter() でもOKです!
PyTorch のローダーを使って取得したデータセットは、イテレータで取得したときに [バッチサイズ, チャネル, (画像のシェイプ)]というテンソル型の画像と、[バッチサイズ] というテンソル型のラベルを返します。
描画結果
実際に50枚描画してみると、以下のようになりました。
32×32 なので粗い画像ですが、画像とラベルが一致していそうですね。
これらの画像を、今から畳み込みニューラルネットワーク(CNN)に学習させていきます!
CNNの構築
それでは、早速ネットワークを構築していきます。
今回は以下のような、「LeNet」 と呼ばれる畳み込みニューラルネットワーク(CNN)をベースに構築し、学習させていきます。
「LeNet」の構成
「LeNet」が提案されたのは1998年と古いものですが、畳み込みニューラルネットワーク(CNN)という名を有名にさせたネットワークです。
このネットワークを、ほとんどそのまま実装してみると以下のようになります。
(実際には活性化関数など、一部元論文と異なります)
実装
1class MyCNN(torch.nn.Module):
2 def __init__(self):
3 super(MyCNN, self).__init__()
4 self.conv1 = torch.nn.Conv2d(3, # チャネル入力
5 6, # チャンネル出力
6 5, # カーネルサイズ
7 1, # ストライド (デフォルトは1)
8 0, # パディング (デフォルトは0)
9 )
10 self.conv2 = torch.nn.Conv2d(6, 16, 5)
11
12 self.pool = torch.nn.MaxPool2d(2, 2) # カーネルサイズ, ストライド
13
14 self.fc1 = torch.nn.Linear(16 * 5 * 5, 120) # 入力サイズ, 出力サイズ
15 self.fc2 = torch.nn.Linear(120, 84)
16 self.fc3 = torch.nn.Linear(84, 10)
17
18 def forward(self, x):
19 x = f.relu(self.conv1(x))
20 x = self.pool(x)
21 x = f.relu(self.conv2(x))
22 x = self.pool(x)
23 x = x.view(-1, 16 * 5 * 5) # 1次元データに変えて全結合層へ
24 x = f.relu(self.fc1(x))
25 x = f.relu(self.fc2(x))
26 x = self.fc3(x)
27
28 return x
ネットワーク名は MyCNN としました。
通常の画像を畳み込む場合、torch.nn.Conv2d を用いますが引数についてはコメントのとおりです。
構築の仕方は「MNIST」の時とほとんど同じなので分かりやすいかと思います。
このとき、(入力チャネル)×(出力チャネル)が畳み込みフィルタの数になり、これらはネットワークが構築された段階でランダムに初期化されます。
この畳み込みフィルタは、『学習の過程でどう変化していくのか』を観察する予定です。
畳み込みフィルタを可視化する関数をつくる
では最初に、可視化用の関数をつくっていきましょう!
先ほどの MyCNN クラスのメンバ関数でOKですので、以下の関数を加筆してください。
1 def plot_conv1(self, prefix_num=0):
2 weights1 = self.conv1.weight
3 weights1 = weights1.reshape(3*6, 5, 5)
4
5 for i, weight in enumerate(weights1):
6 plt.subplot(3, 6, i + 1)
7 plt.imshow(weight.data.numpy(), cmap='winter')
8 plt.tick_params(labelbottom=False,
9 labelleft=False,
10 labelright=False,
11 labeltop=False,
12 bottom=False,
13 left=False,
14 right=False,
15 top=False)
16
17 plt.savefig('img/{}_conv1.png'.format(prefix_num))
18 plt.close()
19
20 def plot_conv2(self, prefix_num=0):
21 weights2 = self.conv2.weight
22 weights2 = weights2.reshape(6*16, 5, 5)
23
24 for i, weight in enumerate(weights2):
25 plt.subplot(6, 16, i + 1)
26 plt.imshow(weight.data.numpy(), cmap='winter')
27 plt.tick_params(labelbottom=False,
28 labelleft=False,
29 labelright=False,
30 labeltop=False,
31 bottom=False,
32 left=False,
33 right=False,
34 top=False)
35
36 plt.savefig('img/{}_conv2.png'.format(prefix_num))
37 plt.close()
各レイヤーの重み情報は weight というメンバ変数が保持しています。
これは単純な重み情報だけでなく、勾配情報やデバイス情報(CPU or GPU)などを保持しているので、純粋な重みを取り出す場合 weight.data とします。
あとは、先ほどの「CIFAR10」の可視化と、ほとんど一緒ですね。
ちなみに、学習前の重みはこんな感じです。
ただ、ランダムなので、まだ何がなんだかよくわかりませんね(笑)
学習部を作る
それでは、学習部を実装していきます。
これも「MNIST」の時とほとんど同様ですが、今回は損失関数として「クロスエントロピー」を用います。
メイン処理部分の、ネットワーク構築から訓練までは以下のようにしました。
1if __name__ == '__main__':
2 epoch = 50
3
4 loader = load_cifar10()
5 classes = ('plane', 'car', 'bird', 'cat', 'deer',
6 'dog', 'frog', 'horse', 'ship', 'truck')
7
8 net: MyCNN = MyCNN()
9 criterion = torch.nn.CrossEntropyLoss() # ロスの計算
10 optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
11
12 # 学習前のフィルタの可視化
13 net.plot_conv1()
14 net.plot_conv2()
15
16 history = {
17 'train_loss': [],
18 'train_acc': [],
19 'test_acc': []
20 }
21
22 for e in range(epoch):
23 net.train()
24 loss = None
25 for i, (images, labels) in enumerate(loader['train']):
26 optimizer.zero_grad()
27 output = net(images)
28 loss = criterion(output, labels)
29 loss.backward()
30 optimizer.step()
31
32 if i % 10 == 0:
33 print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1,
34 (i + 1) * 128,
35 loss.item())
36 )
37
38 # 学習過程でのフィルタの可視化
39 net.plot_conv1(e+1)
40 net.plot_conv2(e+1)
41
42 history['train_loss'].append(loss.item())
ひとまず、学習は50エポック分行いたいと思います。
最終的な結果として、訓練ロスと訓練精度、そしてテスト精度の変化が得られるようにします。
それらは history という名の、辞書型変数に格納する形になっています。
テスト部分
次にテスト部分ですが、これも「MNIST」の時と大差はありません。
今回は、訓練精度とテスト精度を見たいので、2つのループがあります。
実装
テスト部分から最後の結果を描画するまでは、以下のように実装しました。
1 net.eval()
2 correct = 0
3 with torch.no_grad():
4 for i, (images, labels) in enumerate(tqdm(loader['train'])):
5 outputs = net(images)
6 _, predicted = torch.max(outputs.data, 1)
7 correct += (predicted == labels).sum().item()
8
9 acc = float(correct / 50000)
10 history['train_acc'].append(acc)
11
12 correct = 0
13 with torch.no_grad():
14 for i, (images, labels) in enumerate(tqdm(loader['test'])):
15 outputs = net(images)
16 _, predicted = torch.max(outputs.data, 1)
17 correct += (predicted == labels).sum().item()
18
19 acc = float(correct / 10000)
20 history['test_acc'].append(acc)
21
22 # 結果をプロット
23 plt.plot(range(1, epoch+1), history['train_loss'])
24 plt.title('Training Loss [CIFAR10]')
25 plt.xlabel('epoch')
26 plt.ylabel('loss')
27 plt.savefig('img/cifar10_loss.png')
28 plt.close()
29
30 plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc')
31 plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc')
32 plt.title('Accuracies [CIFAR10]')
33 plt.xlabel('epoch')
34 plt.ylabel('accuracy')
35 plt.legend()
36 plt.savefig('img/cifar10_acc.png')
37 plt.close()
それでは、早速学習させてみます!
結果が気になるところですが…今回はここまで!
次回は、実際に学習させてみたり、GPUを使ってみたりしたいと思いますのでお楽しみに!
後編はこちら
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
関連記事
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit