
【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
2020.08.19
PyTorchでCIFAR-10をCNNに学習させる
前回の『【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる』に引き続き、PyTorchで機械学習を学んでいきましょう!
今回は、PyTorchで畳み込みニューラルネットワーク(CNN)を実装していきます。
ちなみに、公式ドキュメントにも同じような実装が紹介されているようです。
ですが、本記事では、日本語で分かりやすく詳細に解説していきたいと思っています。
さらに最後には、ネットワークの内部を可視化してみたり、GPUを使用してみたりと、様々な実験が含まれている記事になっています!
ですので、ぜひ最後まで読んでみてください!
インポートされているモジュール
これから実装するコードは、以下のモジュールがあらかじめインポートされています。
1 2 3 4 5 6 | import torch import torch.nn.functional as f from torch.utils.data import DataLoader from torchvision import datasets, transforms import matplotlib.pyplot as plt from tqdm import tqdm |
それでは、実際に実装していきましょう!
こちらの記事もオススメ!
CIFAR10の準備
今回用いるデータセットの「CIFAR10 (サイファー10)」 は、32×32 のカラー画像からなるデータセットで、その名の通り10クラスあります。
「MNIST」は 28×28 のグレースケール画像なので、「CIFER10」の方が情報量は数倍多く、学習は難しいです。
畳み込みニューラルネットワークを学んだり、研究しているほとんどの人は、このデータセットをカラー画像の簡易なベンチマークとして使用しています。
それでは、このデータセットを読み込む関数を作っていきます。
CIFER10を読み込む関数を作る
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | def load_cifar10(batch=128): train_loader = DataLoader( datasets.CIFAR10('./data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize( [0.5, 0.5, 0.5], # RGB 平均 [0.5, 0.5, 0.5] # RGB 標準偏差 ) ])), batch_size=batch, shuffle=True ) test_loader = DataLoader( datasets.CIFAR10('./data', train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize( [0.5, 0.5, 0.5], # RGB 平均 [0.5, 0.5, 0.5] # RGB 標準偏差 ) ])), batch_size=batch, shuffle=True ) return {'train': train_loader, 'test': test_loader} |
「MNIST」で行った時とほとんど変わりませんが、今回データの正規化として、各カラーチャネルの平均と標準偏差を「0.5」になるようにしています。
画像データセットでよくある「正規化」ですので、覚えておきましょう!
動作確認
それでは、試しに動作確認をしてみましょう!
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | if __name__ == '__main__': loader = load_cifar10() classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # CIFAR10のクラス for images, labels in loader['train']: print(images.shape) # torch.Size([128, 3, 32, 32]) # 試しに50枚を 5x10 で見てみる for i in range(5): for j in range(10): image = images[i*10+j] / 2 + 0.5 image = image.numpy() plt.subplot(5, 10, i*10+j + 1) plt.imshow(np.transpose(image, (1, 2, 0))) # matplotlibではチャネルは第3次元 # 対応するラベル plt.title(classes[int(labels[i*10+j])]) # 軸目盛や値はいらないので消す plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False, bottom=False, left=False, right=False, top=False) plt.show() break |
画像の描画はイテレーションループを使ってますが、もちろん iter() でもOKです!
PyTorch のローダーを使って取得したデータセットは、イテレータで取得したときに [バッチサイズ, チャネル, (画像のシェイプ)]というテンソル型の画像と、[バッチサイズ] というテンソル型のラベルを返します。
描画結果
実際に50枚描画してみると、以下のようになりました。

CIFAR10の例
32×32 なので粗い画像ですが、画像とラベルが一致していそうですね。
これらの画像を、今から畳み込みニューラルネットワーク(CNN)に学習させていきます!
CNNの構築
それでは、早速ネットワークを構築していきます。
今回は以下のような、「LeNet」 と呼ばれる畳み込みニューラルネットワーク(CNN)をベースに構築し、学習させていきます。
「LeNet」の構成
「LeNet」が提案されたのは1998年と古いものですが、畳み込みニューラルネットワーク(CNN)という名を有名にさせたネットワークです。

LeNetの構成
このネットワークを、ほとんどそのまま実装してみると以下のようになります。
(実際には活性化関数など、一部元論文と異なります)
実装
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | class MyCNN(torch.nn.Module): def __init__(self): super(MyCNN, self).__init__() self.conv1 = torch.nn.Conv2d(3, # チャネル入力 6, # チャンネル出力 5, # カーネルサイズ 1, # ストライド (デフォルトは1) 0, # パディング (デフォルトは0) ) self.conv2 = torch.nn.Conv2d(6, 16, 5) self.pool = torch.nn.MaxPool2d(2, 2) # カーネルサイズ, ストライド self.fc1 = torch.nn.Linear(16 * 5 * 5, 120) # 入力サイズ, 出力サイズ self.fc2 = torch.nn.Linear(120, 84) self.fc3 = torch.nn.Linear(84, 10) def forward(self, x): x = f.relu(self.conv1(x)) x = self.pool(x) x = f.relu(self.conv2(x)) x = self.pool(x) x = x.view(-1, 16 * 5 * 5) # 1次元データに変えて全結合層へ x = f.relu(self.fc1(x)) x = f.relu(self.fc2(x)) x = self.fc3(x) return x |
ネットワーク名は MyCNN としました。
通常の画像を畳み込む場合、 torch.nn.Conv2d を用いますが引数についてはコメントのとおりです。
構築の仕方は「MNIST」の時とほとんど同じなので分かりやすいかと思います。
このとき、(入力チャネル)×(出力チャネル)が畳み込みフィルタの数になり、これらはネットワークが構築された段階でランダムに初期化されます。
この畳み込みフィルタは、『学習の過程でどう変化していくのか』を観察する予定です。
畳み込みフィルタを可視化する関数をつくる
では最初に、可視化用の関数をつくっていきましょう!
先ほどの MyCNN クラスのメンバ関数でOKですので、以下の関数を加筆してください。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | def plot_conv1(self, prefix_num=0): weights1 = self.conv1.weight weights1 = weights1.reshape(3*6, 5, 5) for i, weight in enumerate(weights1): plt.subplot(3, 6, i + 1) plt.imshow(weight.data.numpy(), cmap='winter') plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False, bottom=False, left=False, right=False, top=False) plt.savefig('img/{}_conv1.png'.format(prefix_num)) plt.close() def plot_conv2(self, prefix_num=0): weights2 = self.conv2.weight weights2 = weights2.reshape(6*16, 5, 5) for i, weight in enumerate(weights2): plt.subplot(6, 16, i + 1) plt.imshow(weight.data.numpy(), cmap='winter') plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False, bottom=False, left=False, right=False, top=False) plt.savefig('img/{}_conv2.png'.format(prefix_num)) plt.close() |
各レイヤーの重み情報は weight というメンバ変数が保持しています。
これは単純な重み情報だけでなく、勾配情報やデバイス情報(CPU or GPU)などを保持しているので、純粋な重みを取り出す場合 weight.data とします。
あとは、先ほどの「CIFAR10」の可視化と、ほとんど一緒ですね。
ちなみに、学習前の重みはこんな感じです。

学習前のconv1

学習前のconv2
ただ、ランダムなので、まだ何がなんだかよくわかりませんね(笑)
学習部を作る
それでは、学習部を実装していきます。
これも「MNIST」の時とほとんど同様ですが、今回は損失関数として「クロスエントロピー」を用います。
メイン処理部分の、ネットワーク構築から訓練までは以下のようにしました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | if __name__ == '__main__': epoch = 50 loader = load_cifar10() classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') net: MyCNN = MyCNN() criterion = torch.nn.CrossEntropyLoss() # ロスの計算 optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9) # 学習前のフィルタの可視化 net.plot_conv1() net.plot_conv2() history = { 'train_loss': [], 'train_acc': [], 'test_acc': [] } for e in range(epoch): net.train() loss = None for i, (images, labels) in enumerate(loader['train']): optimizer.zero_grad() output = net(images) loss = criterion(output, labels) loss.backward() optimizer.step() if i % 10 == 0: print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1, (i + 1) * 128, loss.item()) ) # 学習過程でのフィルタの可視化 net.plot_conv1(e+1) net.plot_conv2(e+1) history['train_loss'].append(loss.item()) |
ひとまず、学習は50エポック分行いたいと思います。
最終的な結果として、訓練ロスと訓練精度、そしてテスト精度の変化が得られるようにします。
それらは history という名の、辞書型変数に格納する形になっています。
テスト部分
次にテスト部分ですが、これも「MNIST」の時と大差はありません。
今回は、訓練精度とテスト精度を見たいので、2つのループがあります。
実装
テスト部分から最後の結果を描画するまでは、以下のように実装しました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | net.eval() correct = 0 with torch.no_grad(): for i, (images, labels) in enumerate(tqdm(loader['train'])): outputs = net(images) _, predicted = torch.max(outputs.data, 1) correct += (predicted == labels).sum().item() acc = float(correct / 50000) history['train_acc'].append(acc) correct = 0 with torch.no_grad(): for i, (images, labels) in enumerate(tqdm(loader['test'])): outputs = net(images) _, predicted = torch.max(outputs.data, 1) correct += (predicted == labels).sum().item() acc = float(correct / 10000) history['test_acc'].append(acc) # 結果をプロット plt.plot(range(1, epoch+1), history['train_loss']) plt.title('Training Loss [CIFAR10]') plt.xlabel('epoch') plt.ylabel('loss') plt.savefig('img/cifar10_loss.png') plt.close() plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc') plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc') plt.title('Accuracies [CIFAR10]') plt.xlabel('epoch') plt.ylabel('accuracy') plt.legend() plt.savefig('img/cifar10_acc.png') plt.close() |
それでは、早速学習させてみます!
結果が気になるところですが…今回はここまで!
次回は、実際に学習させてみたり、GPUを使ってみたりしたいと思いますのでお楽しみに!
後編はこちら
こちらの記事もオススメ!
関連記事
書いた人はこんな人

- 「好きを仕事にするエンジニア集団」の(株)ライトコードです!
ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。
システム開発依頼・お見積もり大歓迎!
また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。
以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit
ITエンタメ10月 13, 2023Netflixの成功はレコメンドエンジン?
ライトコードの日常8月 30, 2023退職者の最終出社日に密着してみた!
ITエンタメ8月 3, 2023世界初の量産型ポータブルコンピュータを開発したのに倒産!?アダム・オズボーン
ITエンタメ7月 14, 2023【クリス・ワンストラス】GitHubが出来るまでとソフトウェアの未来