1. HOME
  2. ブログ
  3. IT技術
  4. 【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

【後編】PyTorchでCIFAR-10をCNNに学習させる

【前編】の続きとなります。

引き続き、PyTorch(パイトーチ)で畳み込みニューラルネットワーク(CNN)を実装していきたいと思います。

今回は、学習結果からとなります!

前編の記事はこちら

こちらの記事もオススメ!

学習結果

学習が終わりましたが、やはり「MNIST」と違って学習に時間がかかりますね!

ですが、50エポックなので数十分で終わるかと思います。(マシンスペックに依存しますが…)

訓練ロスと訓練 / テスト精度

学習によって得られた、『訓練ロス』『訓練 / テスト精度』から見てみましょう。

cifar10_loss

学習結果 (訓練Loss)

cifar10_acc

学習結果 (精度)

まだ、精度は「70%程度」と低いですが、しっかり学習できていそうですね!

畳み込み層のフィルタ

ちなみに、学習後の畳み込み層のフィルタを見てみると

50_conv1

学習後 のconv.1の重み (50 epoch)

50_conv2

学習後のconv.2の重み (50 epoch)

学習前と比べて、何かしらフィルタに模様が見えてきましたね。

よく見ると、斜め方向に対応するフィルタや、横方向に対応するフィルタが見受けられますが、まだはっきりとは分かりませんね。

またグラフを見ると、学習回数を増やせば、まだ精度は伸びそうな雰囲気があります。

「もっともっと学習を増やしてみましょう!」

…と言いたいところですが、そうなると学習に膨大な時間がかかってしまいそうです。

GPUを使ってみる

CUDA(Compute Unified Device Architecture:クーダ)」が使用可能であれば、PyTorchでは、簡単にGPUに演算を行わせることができます。

メイン処理部の冒頭に、以下を加筆して下さい。

torch.cuda.is_available() は、「CUDA」が使用可能ならTrueを返すといった、GPU を使用できるかを簡単に確認できる関数です。

その後、ネットワークを to() でデバイスに投げるだけです。

しかし、これだけではエラーを吐かれてしまいます。

GPU に使用するデータセットを投げる

使用するデータセットも、GPU に投げる必要があるので、各データローダーのループに、以下のように加筆して下さい。

matplotlib で重みを描画するには

最後に、重みを描画する「matplotlib」では、CPU で描画するため、その際は逆に CPU に投げる必要があります

以下のように書けば OK です!

これで、GPU を使用する準備が整いました!

早速学習させてみましょう!

ちなみに筆者の環境は、GPU は「NVIDIA GeForce GTX Titan Blak 6GB」で、「CUDA」はバーション9.2で動作確認をしています。

学習結果 (300エポック)

さすがGPUを使うと、目に見えて学習が早いです!

時間はしっかりと測定していませんが、1.5 ~ 2倍くらい早いです。

では早速、300エポックの学習結果を見てみましょう!

cifar10_acc

300エポックの学習結果…?

なんと、訓練精度は「約100%」になりましたが、テスト精度は「60%程度」になってしまいました

「過学習」が起こってしまいました!

【過学習とは?】
過学習とは、このように訓練データに適応しすぎて、テストデータなどに対する性能、いわゆる汎化性能が低下してしまう現象を言います。

Dropoutを導入してみる

それでは、過学習対策としてメジャーな「Dropout」を導入してみましょう。

「Dropout」とは、簡単に言えば、学習時に一部のニューロンを、わざと非活性化させ、訓練データに適合しすぎないようにする手法です。

追加してみました!

さて、結果は、どうなるでしょうか!?

実験結果 (300エポック + Dropout)

cifar10_acc

学習結果

過学習はなくなりましたね!

ただ、やはり精度は「70%程度」といったところでしょうか。

もしかしたら、これが「LeNet」の限界なのかもしれません…。

他にも「Batch Normalization(バッチ正規化)」や、「Data Augmentation(データ拡張)」などの手法を用いれば、過学習を抑制しつつ精度向上が見込めるかもしれません。

ですが、本記事ではここまでとします。

ここまできたら、もう少し深い層の、畳み込みニューラルネットワーク (DCNN: Deep Convolutional Neural Networks)を構築した方が良いでしょう。

フィルタの重みの学習結果

ちなみに、フィルタの重みの学習結果も載せておきますが、50エポックの時と、大した差はありませんね。

やや、模様が明確になったような気もします(笑)

300_conv1

Conv.1のフィルタ

300_conv2

Conv.2のフィルタ

さいごに

長丁場になりましたが、今回は、PyTorchで畳み込みニューラルネットワークを構築し、カラー画像の「CIFAR10」を学習させてみました。

また、ネットワークの内部を観察したり、いろいろな考察もしてみたので、機械学習初学者の皆さまの参考になれば幸いです。

次回は、もう少し複雑なネットワークで試してみようと考えているのでお楽しみに!

ソースコード

前編の記事はこちら

こちらの記事もオススメ!

関連記事

ライトコードよりお知らせ

にゃんこ師匠にゃんこ師匠
システム開発のご相談やご依頼はこちら
ミツオカミツオカ
ライトコードの採用募集はこちら
にゃんこ師匠にゃんこ師匠
社長と一杯飲みながらお話してみたい方はこちら
ミツオカミツオカ
フリーランスエンジニア様の募集はこちら
にゃんこ師匠にゃんこ師匠
その他、お問い合わせはこちら
ミツオカミツオカ
   
お気軽にお問い合わせください!せっかくなので、別の記事もぜひ読んでいって下さいね!

一緒に働いてくれる仲間を募集しております!

ライトコードでは、仲間を募集しております!

当社のモットーは「好きなことを仕事にするエンジニア集団」「エンジニアによるエンジニアのための会社」。エンジニアであるあなたの「やってみたいこと」を全力で応援する会社です。

また、ライトコードは現在、急成長中!だからこそ、あなたにお任せしたいやりがいのあるお仕事は沢山あります。「コアメンバー」として活躍してくれる、あなたからのご応募をお待ちしております!

なお、ご応募の前に、「話しだけ聞いてみたい」「社内の雰囲気を知りたい」という方はこちらをご覧ください。

ライトコードでは一緒に働いていただける方を募集しております!

採用情報はこちら

書いた人はこんな人

ライトコードメディア編集部
ライトコードメディア編集部
「好きなことを仕事にするエンジニア集団」の(株)ライトコードのメディア編集部が書いている記事です。

関連記事