1. HOME
  2. ブログ
  3. IT技術
  4. 【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

PyTorchでCIFAR-10をCNNに学習させる

前回の『【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる』に引き続き、PyTorchで機械学習を学んでいきましょう!

今回は、PyTorchで畳み込みニューラルネットワーク(CNN)を実装していきます。

ちなみに、公式ドキュメントにも同じような実装が紹介されているようです。

ですが、本記事では、日本語で分かりやすく詳細に解説していきたいと思っています。

さらに最後には、ネットワークの内部を可視化してみたり、GPUを使用してみたりと、様々な実験が含まれている記事になっています!

ですので、ぜひ最後まで読んでみてください!

インポートされているモジュール

これから実装するコードは、以下のモジュールがあらかじめインポートされています。

それでは、実際に実装していきましょう!

こちらの記事もオススメ!

CIFAR10の準備

今回用いるデータセットの「CIFAR10 (サイファー10)」 は、32×32 のカラー画像からなるデータセットで、その名の通り10クラスあります。

「MNIST」は 28×28 のグレースケール画像なので、「CIFER10」の方が情報量は数倍多く、学習は難しいです。

畳み込みニューラルネットワークを学んだり、研究しているほとんどの人は、このデータセットをカラー画像の簡易なベンチマークとして使用しています。

それでは、このデータセットを読み込む関数を作っていきます。

CIFER10を読み込む関数を作る

「MNIST」で行った時とほとんど変わりませんが、今回データの正規化として、各カラーチャネルの平均と標準偏差を「0.5」になるようにしています。

画像データセットでよくある「正規化」ですので、覚えておきましょう!

動作確認

それでは、試しに動作確認をしてみましょう!

画像の描画はイテレーションループを使ってますが、もちろん iter() でもOKです!

PyTorch のローダーを使って取得したデータセットは、イテレータで取得したときに [バッチサイズ, チャネル, (画像のシェイプ)]というテンソル型の画像と、[バッチサイズ] というテンソル型のラベルを返します。

描画結果

実際に50枚描画してみると、以下のようになりました。

CIFAR10の例

32×32 なので粗い画像ですが、画像とラベルが一致していそうですね。

これらの画像を、今から畳み込みニューラルネットワーク(CNN)に学習させていきます!

CNNの構築

それでは、早速ネットワークを構築していきます。

今回は以下のような、「LeNet」 と呼ばれる畳み込みニューラルネットワーク(CNN)をベースに構築し、学習させていきます。

「LeNet」の構成

「LeNet」が提案されたのは1998年と古いものですが、畳み込みニューラルネットワーク(CNN)という名を有名にさせたネットワークです。

LeNetの構成

このネットワークを、ほとんどそのまま実装してみると以下のようになります。

(実際には活性化関数など、一部元論文と異なります)

実装

ネットワーク名は MyCNN としました。

通常の画像を畳み込む場合、 torch.nn.Conv2d を用いますが引数についてはコメントのとおりです。

構築の仕方は「MNIST」の時とほとんど同じなので分かりやすいかと思います。

このとき、(入力チャネル)×(出力チャネル)が畳み込みフィルタの数になり、これらはネットワークが構築された段階でランダムに初期化されます。

この畳み込みフィルタは、『学習の過程でどう変化していくのか』を観察する予定です。

畳み込みフィルタを可視化する関数をつくる

では最初に、可視化用の関数をつくっていきましょう!

先ほどの MyCNN クラスのメンバ関数でOKですので、以下の関数を加筆してください。

各レイヤーの重み情報は weight というメンバ変数が保持しています。

これは単純な重み情報だけでなく、勾配情報やデバイス情報(CPU or GPU)などを保持しているので、純粋な重みを取り出す場合 weight.data とします。

あとは、先ほどの「CIFAR10」の可視化と、ほとんど一緒ですね。

ちなみに、学習前の重みはこんな感じです。

学習前のconv1

学習前のconv2

ただ、ランダムなので、まだ何がなんだかよくわかりませんね(笑)

学習部を作る

それでは、学習部を実装していきます。

これも「MNIST」の時とほとんど同様ですが、今回は損失関数として「クロスエントロピー」を用います。

メイン処理部分の、ネットワーク構築から訓練までは以下のようにしました。

ひとまず、学習は50エポック分行いたいと思います。

最終的な結果として、訓練ロスと訓練精度、そしてテスト精度の変化が得られるようにします。

それらは history という名の、辞書型変数に格納する形になっています。

テスト部分

次にテスト部分ですが、これも「MNIST」の時と大差はありません。

今回は、訓練精度とテスト精度を見たいので、2つのループがあります。

実装

テスト部分から最後の結果を描画するまでは、以下のように実装しました。

それでは、早速学習させてみます!

結果が気になるところですが…今回はここまで!

次回は、実際に学習させてみたり、GPUを使ってみたりしたいと思いますのでお楽しみに!

後編はこちら

こちらの記事もオススメ!

関連記事

書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア