1. HOME
  2. ブログ
  3. IT技術
  4. 【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる

【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる

PyTorchで手書き数字(MNIST)を学習させる

前回は、PyTorch(パイトーチ)のインストールなどを行いました。

今回は、いよいよPyTorchで手書き数字(MNIST)データセットを学習させていきたいと思います!

前回の記事はこちら

早速実装してみる

それでは、実装を始めます。

入門編ということで、一つ一つ丁寧に解説していきます!

必要なモジュールのインポート

今回使うモジュールを先に公開しておきます。

今回は、これだけ使用します。

(つまり最低でも、前回インストールした2つのモジュールがあればOKです)

ネットワークの構築

まずは、今回使うネットワークを定義していきます。

PyTorchでは、 torch.nn.Module というクラスを継承して、オリジナルのネットワークを構築していきます。

今回は MyNet という名前でネットワークを作っていきますが、ネットワーク構成はシンプルに「入力層(784) – 中間層(1000) – 出力層(10)」の3層構造とします。

  1. 中間層の活性化関数に「シグモイド(sigmoid)関数」
  2. 出力は確率にしたいので「ソフトマックス(softmax)関数」

中間層の活性化関数に「シグモイド( sigmoid )関数」を、出力は確率にしたいので「ソフトマックス( softmax )関数」を使用します。

今回は、MNISTという簡単なタスクで、なおかつ畳み込み層はないので、よく使用される ReLU関数 は使いません。(もちろん使っても良いです笑)

PyTorchでは、以上のようなネットワークの場合以下のように定義していきます。

とてもシンプルです。

最低限、コンストラクタ( def __init__() )と順伝播の関数( def forward() )を定義すればOKです。

データセット(MNIST)のロード

MNISTをロードする関数を作りましょう。

PyTorchでは、TorchVisionというモジュールでデータセットを管理しています。

まずは、出来上がった関数を見てみましょう。

PyTorchでは、データローダーという形でデータを取り扱うことが大きな特徴の一つです。

このデータローダーには、バッチサイズごとにまとめられたデータとラベルがまとまっています

さらにデータは、 torch.tensor というテンソルの形で扱いますが、データローダーにおけるデータの形は(batch, channel, dimension)という順番になっています。

これは後で、実際に見てみましょう。

また、 torch.utils.data.DataLoader() では、第一引数に「データセット」を取ります。

今回は、その第一引数に datasets.MNIST() というMNISTのデータを扱うためのクラスインスタンスが与えられていることが分かります。

このクラス( datasets.MNIST())では、コンストラクタとして第一引数にデータのダウンロード先を指定し、そのほかに訓練データか否か( train=True なら訓練データ、 train=False ならテストデータ)を指定したり、 transform= でデータを正規化したりできます。

今回は、画素値の最大値を intensity 倍するような形ですが、他によく見る形として、

のように、平均と分散を指定すると良い精度になる場合もあります。

今回用意した関数では、戻り値として各ローダーを辞書型変数にして返しています。

メイン処理部分を書く

下準備完了です!

早速学習させる部分を実装していきます。

まずは、ネットワークを構築して、データを取得するまでを示します。

これだけでOKです。

最適化

次は、学習率にどのような最適化を適用するかを決めます。

今回は、Adam という最適化手法を使ってみましょう。

初期学習率は、0.001 としました。

学習部分の実装

では、核となる学習部分の実装に移ります。

大枠としては、下記のように学習回数のループの中に、訓練データのループとテスト(検証)データのループを作ります。

それでは、早速中身を書いていきましょう。

訓練部分の実装

訓練部分は、以下のようにコーディングしてみました。

ここで実際にデータの形を出力してみると、先ほど話をしたように(batch, channel, dimension)になっていることがわかります。

ネットワークにデータを入力して出力を得るまでは、 output = net(data) だけで済むのは簡単ですね!

今回入力は、1次元でグレースケールなので、 data = data.view(-1, 28*28) で形を調整します。

そのあとは、ロスを計算( loss = f.nll_loss(output, target) )して、そのロスを元に誤差を逆伝播( loss.backward() )しているだけです。

ログは、10batch 毎に出力するようにしてみました。

テスト部分の作成

これで訓練部分は完成したので、テスト部分を作ります。

これも先ほどのコードと似ている部分がたくさんありますね。

テスト部分のロスは全て足して、最後に平均を取ることで、その学習(epoch)でのロスとしています。

また、テスト部分では精度も測りたいので、softmaxの確率出力の中で一番大きいニューロンのインデックスを取得しています ( pred = output.argmax(dim=1, keepdim=True) )。

このあと、ラベルと比較して一致しているものを正解数として記録しています。

完成!最終的なコード

最後に、結果を描画する部分を加筆して完成です!

ちなみに今回書いたコードは、「これが正解・最適」というわけではなく、筆者の好みも現れていますので、適宜自分の理解しやすいようにコーディングしてください!

動作確認

実際に動かしてみると、学習後に以下のような図が得られるはずです!

PyTorchでMNIST_Loss

PyTorchでMNIST_Accuracy

訓練ロスが若干バタついていますが、テスト精度は98%以上と、しっかり学習できていそうですね!

さいごに

今回は、PyTorchの入門編という立ち位置で「MNISTを単純なネットワークで学習」させてみました。

実際にコードを見てみると、機械学習初心者でも比較的馴染みやすい書き方だと思います。

これから機械学習を始める方、新しい機械学習ライブラリを探していた方、是非一度触ってみてください!

最初に言ったように、おそらく、これからどんどんホットになっていく機械学習ライブラリがPyTorchです。

次回の記事では、畳み込みニューラルネットワークを使って、もう少し難しいタスクに挑戦してみますのでお楽しみに!

関連記事

ライトコードよりお知らせ

にゃんこ師匠にゃんこ師匠
システム開発のご相談やご依頼はこちら
ミツオカミツオカ
ライトコードの採用募集はこちら
にゃんこ師匠にゃんこ師匠
社長と一杯飲みながらお話してみたい方はこちら
ミツオカミツオカ
フリーランスエンジニア様の募集はこちら
にゃんこ師匠にゃんこ師匠
その他、お問い合わせはこちら
ミツオカミツオカ
   
お気軽にお問い合わせください!せっかくなので、別の記事もぜひ読んでいって下さいね!

一緒に働いてくれる仲間を募集しております!

ライトコードでは、仲間を募集しております!

当社のモットーは「好きなことを仕事にするエンジニア集団」「エンジニアによるエンジニアのための会社」。エンジニアであるあなたの「やってみたいこと」を全力で応援する会社です。

また、ライトコードは現在、急成長中!だからこそ、あなたにお任せしたいやりがいのあるお仕事は沢山あります。「コアメンバー」として活躍してくれる、あなたからのご応募をお待ちしております!

なお、ご応募の前に、「話しだけ聞いてみたい」「社内の雰囲気を知りたい」という方はこちらをご覧ください。

ライトコードでは一緒に働いていただける方を募集しております!

採用情報はこちら

書いた人はこんな人

ライトコードメディア編集部
ライトコードメディア編集部
「好きなことを仕事にするエンジニア集団」の(株)ライトコードのメディア編集部が書いている記事です。

関連記事