1. HOME
  2. ブログ
  3. IT技術
  4. 【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる

【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる

PyTorchで手書き数字(MNIST)を学習させる

前回は、PyTorch(パイトーチ)のインストールなどを行いました。

今回は、いよいよPyTorchで手書き数字(MNIST)データセットを学習させていきたいと思います!

前回の記事はこちら

早速実装してみる

それでは、実装を始めます。

入門編ということで、一つ一つ丁寧に解説していきます!

必要なモジュールのインポート

今回使うモジュールを先に公開しておきます。

今回は、これだけ使用します。

(つまり最低でも、前回インストールした2つのモジュールがあればOKです)

ネットワークの構築

まずは、今回使うネットワークを定義していきます。

PyTorchでは、 torch.nn.Module というクラスを継承して、オリジナルのネットワークを構築していきます。

今回は MyNet という名前でネットワークを作っていきますが、ネットワーク構成はシンプルに「入力層(784) - 中間層(1000) - 出力層(10)」の3層構造とします。

  1. 中間層の活性化関数に「シグモイド(sigmoid)関数」
  2. 出力は確率にしたいので「ソフトマックス(softmax)関数」

中間層の活性化関数に「シグモイド( sigmoid )関数」を、出力は確率にしたいので「ソフトマックス( softmax )関数」を使用します。

今回は、MNISTという簡単なタスクで、なおかつ畳み込み層はないので、よく使用される ReLU関数 は使いません。(もちろん使っても良いです笑)

PyTorchでは、以上のようなネットワークの場合以下のように定義していきます。

とてもシンプルです。

最低限、コンストラクタ( def __init__() )と順伝播の関数( def forward() )を定義すればOKです。

データセット(MNIST)のロード

MNISTをロードする関数を作りましょう。

PyTorchでは、TorchVisionというモジュールでデータセットを管理しています。

まずは、出来上がった関数を見てみましょう。

PyTorchでは、データローダーという形でデータを取り扱うことが大きな特徴の一つです。

このデータローダーには、バッチサイズごとにまとめられたデータとラベルがまとまっています

さらにデータは、 torch.tensor というテンソルの形で扱いますが、データローダーにおけるデータの形は(batch, channel, dimension)という順番になっています。

これは後で、実際に見てみましょう。

また、 torch.utils.data.DataLoader() では、第一引数に「データセット」を取ります。

今回は、その第一引数に datasets.MNIST() というMNISTのデータを扱うためのクラスインスタンスが与えられていることが分かります。

このクラス( datasets.MNIST())では、コンストラクタとして第一引数にデータのダウンロード先を指定し、そのほかに訓練データか否か( train=True なら訓練データ、 train=False ならテストデータ)を指定したり、 transform= でデータを正規化したりできます。

今回は、画素値の最大値を intensity 倍するような形ですが、他によく見る形として、

のように、平均と分散を指定すると良い精度になる場合もあります。

今回用意した関数では、戻り値として各ローダーを辞書型変数にして返しています。

メイン処理部分を書く

下準備完了です!

早速学習させる部分を実装していきます。

まずは、ネットワークを構築して、データを取得するまでを示します。

これだけでOKです。

最適化

次は、学習率にどのような最適化を適用するかを決めます。

今回は、Adam という最適化手法を使ってみましょう。

初期学習率は、0.001 としました。

学習部分の実装

では、核となる学習部分の実装に移ります。

大枠としては、下記のように学習回数のループの中に、訓練データのループとテスト(検証)データのループを作ります。

それでは、早速中身を書いていきましょう。

訓練部分の実装

訓練部分は、以下のようにコーディングしてみました。

ここで実際にデータの形を出力してみると、先ほど話をしたように(batch, channel, dimension)になっていることがわかります。

ネットワークにデータを入力して出力を得るまでは、 output = net(data) だけで済むのは簡単ですね!

今回入力は、1次元でグレースケールなので、 data = data.view(-1, 28*28) で形を調整します。

そのあとは、ロスを計算( loss = f.nll_loss(output, target) )して、そのロスを元に誤差を逆伝播( loss.backward() )しているだけです。

ログは、10batch 毎に出力するようにしてみました。

テスト部分の作成

これで訓練部分は完成したので、テスト部分を作ります。

これも先ほどのコードと似ている部分がたくさんありますね。

テスト部分のロスは全て足して、最後に平均を取ることで、その学習(epoch)でのロスとしています。

また、テスト部分では精度も測りたいので、softmaxの確率出力の中で一番大きいニューロンのインデックスを取得しています ( pred = output.argmax(dim=1, keepdim=True) )。

このあと、ラベルと比較して一致しているものを正解数として記録しています。

完成!最終的なコード

最後に、結果を描画する部分を加筆して完成です!

ちなみに今回書いたコードは、「これが正解・最適」というわけではなく、筆者の好みも現れていますので、適宜自分の理解しやすいようにコーディングしてください!

動作確認

実際に動かしてみると、学習後に以下のような図が得られるはずです!

PyTorchでMNIST_Loss

PyTorchでMNIST_Accuracy

訓練ロスが若干バタついていますが、テスト精度は98%以上と、しっかり学習できていそうですね!

さいごに

今回は、PyTorchの入門編という立ち位置で「MNISTを単純なネットワークで学習」させてみました。

実際にコードを見てみると、機械学習初心者でも比較的馴染みやすい書き方だと思います。

これから機械学習を始める方、新しい機械学習ライブラリを探していた方、是非一度触ってみてください!

最初に言ったように、おそらく、これからどんどんホットになっていく機械学習ライブラリがPyTorchです!

こちらの記事もオススメ!

書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア