1. HOME
  2. ブログ
  3. IT技術
  4. 【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

【後編】PyTorchでCIFAR-10をCNNに学習させる

【前編】の続きとなります。

引き続き、PyTorch(パイトーチ)で畳み込みニューラルネットワーク(CNN)を実装していきたいと思います。

今回は、学習結果からとなります!

前編の記事はこちら

こちらの記事もオススメ!

学習結果

学習が終わりましたが、やはり「MNIST」と違って学習に時間がかかりますね!

ですが、50エポックなので数十分で終わるかと思います。(マシンスペックに依存しますが...)

訓練ロスと訓練 / テスト精度

学習によって得られた、『訓練ロス』『訓練 / テスト精度』から見てみましょう。

学習結果 (訓練Loss)

学習結果 (精度)

まだ、精度は「70%程度」と低いですが、しっかり学習できていそうですね!

畳み込み層のフィルタ

ちなみに、学習後の畳み込み層のフィルタを見てみると

学習後 のconv.1の重み (50 epoch)

学習後のconv.2の重み (50 epoch)

学習前と比べて、何かしらフィルタに模様が見えてきましたね。

よく見ると、斜め方向に対応するフィルタや、横方向に対応するフィルタが見受けられますが、まだはっきりとは分かりませんね。

またグラフを見ると、学習回数を増やせば、まだ精度は伸びそうな雰囲気があります。

「もっともっと学習を増やしてみましょう!」

…と言いたいところですが、そうなると学習に膨大な時間がかかってしまいそうです。

GPUを使ってみる

CUDA(Compute Unified Device Architecture:クーダ)」が使用可能であれば、PyTorchでは、簡単にGPUに演算を行わせることができます。

メイン処理部の冒頭に、以下を加筆して下さい。

torch.cuda.is_available() は、「CUDA」が使用可能ならTrueを返すといった、GPU を使用できるかを簡単に確認できる関数です。

その後、ネットワークを to() でデバイスに投げるだけです。

しかし、これだけではエラーを吐かれてしまいます。

GPU に使用するデータセットを投げる

使用するデータセットも、GPU に投げる必要があるので、各データローダーのループに、以下のように加筆して下さい。

matplotlib で重みを描画するには

最後に、重みを描画する「matplotlib」では、CPU で描画するため、その際は逆に CPU に投げる必要があります

以下のように書けば OK です!

これで、GPU を使用する準備が整いました!

早速学習させてみましょう!

ちなみに筆者の環境は、GPU は「NVIDIA GeForce GTX Titan Blak 6GB」で、「CUDA」はバーション9.2で動作確認をしています。

学習結果 (300エポック)

さすがGPUを使うと、目に見えて学習が早いです!

時間はしっかりと測定していませんが、1.5 ~ 2倍くらい早いです。

では早速、300エポックの学習結果を見てみましょう!

300エポックの学習結果...?

なんと、訓練精度は「約100%」になりましたが、テスト精度は「60%程度」になってしまいました

「過学習」が起こってしまいました!

【過学習とは?】
過学習とは、このように訓練データに適応しすぎて、テストデータなどに対する性能、いわゆる汎化性能が低下してしまう現象を言います。

Dropoutを導入してみる

それでは、過学習対策としてメジャーな「Dropout」を導入してみましょう。

「Dropout」とは、簡単に言えば、学習時に一部のニューロンを、わざと非活性化させ、訓練データに適合しすぎないようにする手法です。

追加してみました!

さて、結果は、どうなるでしょうか!?

実験結果 (300エポック + Dropout)

学習結果

過学習はなくなりましたね!

ただ、やはり精度は「70%程度」といったところでしょうか。

もしかしたら、これが「LeNet」の限界なのかもしれません...。

他にも「Batch Normalization(バッチ正規化)」や、「Data Augmentation(データ拡張)」などの手法を用いれば、過学習を抑制しつつ精度向上が見込めるかもしれません。

ですが、本記事ではここまでとします。

ここまできたら、もう少し深い層の、畳み込みニューラルネットワーク (DCNN: Deep Convolutional Neural Networks)を構築した方が良いでしょう。

フィルタの重みの学習結果

ちなみに、フィルタの重みの学習結果も載せておきますが、50エポックの時と、大した差はありませんね。

やや、模様が明確になったような気もします(笑)

Conv.1のフィルタ

Conv.2のフィルタ

さいごに

長丁場になりましたが、今回は、PyTorchで畳み込みニューラルネットワークを構築し、カラー画像の「CIFAR10」を学習させてみました。

また、ネットワークの内部を観察したり、いろいろな考察もしてみたので、機械学習初学者の皆さまの参考になれば幸いです。

次回は、もう少し複雑なネットワークで試してみようと考えているのでお楽しみに!

ソースコード

前編の記事はこちら

こちらの記事もオススメ!

関連記事

書いた人はこんな人

ライトコード社員ブログ
ライトコード社員ブログ
「好きなことを仕事にするエンジニア集団」の(株)ライトコードです!
ライトコードは、福岡本社、東京オフィスの2拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もりは大歓迎!
また、WEBエンジニアとモバイルエンジニアも積極採用中です!

ご応募をお待ちしております!

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア