【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる
IT技術
PyTorchで手書き数字(MNIST)を学習させる
前回は、PyTorch(パイトーチ)のインストールなどを行いました。
今回は、いよいよPyTorchで手書き数字(MNIST)データセットを学習させていきたいと思います!
前回の記事はこちら
2020.01.20PyTorchの特徴とインストール方法PyTorchとはPyTorch(パイトーチ)とは、Pythonの機械学習ライブラリの一つで、現在最もアツいフレームワ...
早速実装してみる
それでは、実装を始めます。
入門編ということで、一つ一つ丁寧に解説していきます!
必要なモジュールのインポート
今回使うモジュールを先に公開しておきます。
1import torch
2import torch.nn.functional as f
3from torch.utils.data import DataLoader
4from torchvision import datasets, transforms
5import matplotlib.pyplot as plt
今回は、これだけ使用します。
(つまり最低でも、前回インストールした2つのモジュールがあればOKです)
ネットワークの構築
まずは、今回使うネットワークを定義していきます。
PyTorchでは、torch.nn.Module というクラスを継承して、オリジナルのネットワークを構築していきます。
今回はMyNet という名前でネットワークを作っていきますが、ネットワーク構成はシンプルに「入力層(784) - 中間層(1000) - 出力層(10)」の3層構造とします。
- 中間層の活性化関数に「シグモイド(sigmoid)関数」
- 出力は確率にしたいので「ソフトマックス(softmax)関数」
中間層の活性化関数に「シグモイド( sigmoid )関数」を、出力は確率にしたいので「ソフトマックス( softmax )関数」を使用します。
今回は、MNISTという簡単なタスクで、なおかつ畳み込み層はないので、よく使用される ReLU関数 は使いません。(もちろん使っても良いです笑)
PyTorchでは、以上のようなネットワークの場合以下のように定義していきます。
1class MyNet(torch.nn.Module):
2 def __init__(self):
3 super(MyNet, self).__init__()
4 self.fc1 = torch.nn.Linear(28*28, 1000)
5 self.fc2 = torch.nn.Linear(1000, 10)
6
7 def forward(self, x):
8 x = self.fc1(x)
9 x = torch.sigmoid(x)
10 x = self.fc2(x)
11
12 return f.log_softmax(x, dim=1)
とてもシンプルです。
最低限、コンストラクタ(def __init__() )と順伝播の関数(def forward() )を定義すればOKです。
データセット(MNIST)のロード
MNISTをロードする関数を作りましょう。
PyTorchでは、TorchVisionというモジュールでデータセットを管理しています。
まずは、出来上がった関数を見てみましょう。
1def load_MNIST(batch=128, intensity=1.0):
2 train_loader = torch.utils.data.DataLoader(
3 datasets.MNIST('./data',
4 train=True,
5 download=True,
6 transform=transforms.Compose([
7 transforms.ToTensor(),
8 transforms.Lambda(lambda x: x * intensity)
9 ])),
10 batch_size=batch,
11 shuffle=True)
12
13 test_loader = torch.utils.data.DataLoader(
14 datasets.MNIST('./data',
15 train=False,
16 transform=transforms.Compose([
17 transforms.ToTensor(),
18 transforms.Lambda(lambda x: x * intensity)
19 ])),
20 batch_size=batch,
21 shuffle=True)
22
23 return {'train': train_loader, 'test': test_loader}
PyTorchでは、データローダーという形でデータを取り扱うことが大きな特徴の一つです。
このデータローダーには、バッチサイズごとにまとめられたデータとラベルがまとまっています。
さらにデータは、torch.tensor というテンソルの形で扱いますが、データローダーにおけるデータの形は(batch, channel, dimension)という順番になっています。
これは後で、実際に見てみましょう。
また、torch.utils.data.DataLoader() では、第一引数に「データセット」を取ります。
今回は、その第一引数にdatasets.MNIST() というMNISTのデータを扱うためのクラスインスタンスが与えられていることが分かります。
このクラス(datasets.MNIST())では、コンストラクタとして第一引数にデータのダウンロード先を指定し、そのほかに訓練データか否か(train=True なら訓練データ、train=False ならテストデータ)を指定したり、transform= でデータを正規化したりできます。
今回は、画素値の最大値をintensity 倍するような形ですが、他によく見る形として、
1 train_loader = torch.utils.data.DataLoader(
2 datasets.MNIST('./data',
3 train=True,
4 download=True,
5 transform=transforms.Compose([
6 transforms.ToTensor(),
7 transforms.Normalize((0.5,), (0.5,)) # ここが違う
8 ])),
9 batch_size=batch,
10 shuffle=True)
のように、平均と分散を指定すると良い精度になる場合もあります。
今回用意した関数では、戻り値として各ローダーを辞書型変数にして返しています。
メイン処理部分を書く
下準備完了です!
早速学習させる部分を実装していきます。
まずは、ネットワークを構築して、データを取得するまでを示します。
1if __name__ == '__main__':
2 # 学習回数
3 epoch = 20
4
5 # 学習結果の保存用
6 history = {
7 'train_loss': [],
8 'test_loss': [],
9 'test_acc': [],
10 }
11
12 # ネットワークを構築
13 net: torch.nn.Module = MyNet()
14
15 # MNISTのデータローダーを取得
16 loaders = load_MNIST()
これだけでOKです。
最適化
次は、学習率にどのような最適化を適用するかを決めます。
今回は、Adam という最適化手法を使ってみましょう。
1optimizer = torch.optim.Adam(params=net.parameters(), lr=0.001)
初期学習率は、0.001 としました。
学習部分の実装
では、核となる学習部分の実装に移ります。
大枠としては、下記のように学習回数のループの中に、訓練データのループとテスト(検証)データのループを作ります。
1 for e in range(epoch):
2
3 """ Training Part"""
4 loss = None
5
6 # 学習開始 (再開)
7 net.train(True) # 引数は省略可能
8
9 for i, (data, target) in enumerate(loaders['train']):
10 pass
11 ########## 学習部分 ##########
12
13 """ Test Part """
14 # 学習のストップ
15 net.eval() # または net.train(False) でも良い
16
17 with torch.no_grad(): # テスト部分では勾配は使わないのでこのように書く
18 for data, target in loaders['test']:
19 pass
20 ########## テスト部分 ##########
それでは、早速中身を書いていきましょう。
訓練部分の実装
訓練部分は、以下のようにコーディングしてみました。
1""" Training Part"""
2 loss = None
3 # 学習開始 (再開)
4 net.train(True) # 引数は省略可能
5 for i, (data, target) in enumerate(loaders['train']):
6 # 全結合のみのネットワークでは入力を1次元に
7 # print(data.shape) # torch.Size([128, 1, 28, 28])
8 data = data.view(-1, 28*28)
9 # print(data.shape) # torch.Size([128, 784])
10
11 optimizer.zero_grad()
12 output = net(data)
13 loss = f.nll_loss(output, target)
14 loss.backward()
15 optimizer.step()
16
17 if i % 10 == 0:
18 print('Training log: {} epoch ({} / 60000 train. data). Loss: {}'.format(e+1,
19 (i+1)*128,
20 loss.item())
21 )
22
23 history['train_loss'].append(loss)
ここで実際にデータの形を出力してみると、先ほど話をしたように(batch, channel, dimension)になっていることがわかります。
ネットワークにデータを入力して出力を得るまでは、output = net(data) だけで済むのは簡単ですね!
今回入力は、1次元でグレースケールなので、data = data.view(-1, 28*28) で形を調整します。
そのあとは、ロスを計算(loss = f.nll_loss(output, target) )して、そのロスを元に誤差を逆伝播(loss.backward() )しているだけです。
ログは、10batch 毎に出力するようにしてみました。
テスト部分の作成
これで訓練部分は完成したので、テスト部分を作ります。
1""" Test Part """
2 # 学習のストップ
3 net.eval() # または net.train(False) でも良い
4 test_loss = 0
5 correct = 0
6
7 with torch.no_grad():
8 for data, target in loaders['test']:
9 data = data.view(-1, 28 * 28)
10 output = net(data)
11 test_loss += f.nll_loss(output, target, reduction='sum').item()
12 pred = output.argmax(dim=1, keepdim=True)
13 correct += pred.eq(target.view_as(pred)).sum().item()
14
15 test_loss /= 10000
16
17 print('Test loss (avg): {}, Accuracy: {}'.format(test_loss,
18 correct / 10000))
19
20 history['test_loss'].append(test_loss)
21 history['test_acc'].append(correct / 10000)
これも先ほどのコードと似ている部分がたくさんありますね。
テスト部分のロスは全て足して、最後に平均を取ることで、その学習(epoch)でのロスとしています。
また、テスト部分では精度も測りたいので、softmaxの確率出力の中で一番大きいニューロンのインデックスを取得しています (pred = output.argmax(dim=1, keepdim=True) )。
このあと、ラベルと比較して一致しているものを正解数として記録しています。
完成!最終的なコード
最後に、結果を描画する部分を加筆して完成です!
ちなみに今回書いたコードは、「これが正解・最適」というわけではなく、筆者の好みも現れていますので、適宜自分の理解しやすいようにコーディングしてください!
1"""
2PyTorchでMNISTを学習させる
3
4:summary PyTorchで単純な多層パーセプトロンを構築してみる
5:author RightCode Inc. (https://rightcode.co.jp)
6"""
7
8import torch
9import torch.nn.functional as f
10from torch.utils.data import DataLoader
11from torchvision import datasets, transforms
12import matplotlib.pyplot as plt
13
14
15class MyNet(torch.nn.Module):
16 def __init__(self):
17 super(MyNet, self).__init__()
18 self.fc1 = torch.nn.Linear(28*28, 1000)
19 self.fc2 = torch.nn.Linear(1000, 10)
20
21 def forward(self, x):
22 x = self.fc1(x)
23 x = torch.sigmoid(x)
24 x = self.fc2(x)
25
26 return f.log_softmax(x, dim=1)
27
28
29def load_MNIST(batch=128, intensity=1.0):
30 train_loader = torch.utils.data.DataLoader(
31 datasets.MNIST('./data',
32 train=True,
33 download=True,
34 transform=transforms.Compose([
35 transforms.ToTensor(),
36 transforms.Lambda(lambda x: x * intensity)
37 ])),
38 batch_size=batch,
39 shuffle=True)
40
41 test_loader = torch.utils.data.DataLoader(
42 datasets.MNIST('./data',
43 train=False,
44 transform=transforms.Compose([
45 transforms.ToTensor(),
46 transforms.Lambda(lambda x: x * intensity)
47 ])),
48 batch_size=batch,
49 shuffle=True)
50
51 return {'train': train_loader, 'test': test_loader}
52
53
54if __name__ == '__main__':
55 # 学習回数
56 epoch = 20
57
58 # 学習結果の保存用
59 history = {
60 'train_loss': [],
61 'test_loss': [],
62 'test_acc': [],
63 }
64
65 # ネットワークを構築
66 net: torch.nn.Module = MyNet()
67
68 # MNISTのデータローダーを取得
69 loaders = load_MNIST()
70
71 optimizer = torch.optim.Adam(params=net.parameters(), lr=0.001)
72
73 for e in range(epoch):
74
75 """ Training Part"""
76 loss = None
77 # 学習開始 (再開)
78 net.train(True) # 引数は省略可能
79 for i, (data, target) in enumerate(loaders['train']):
80 # 全結合のみのネットワークでは入力を1次元に
81 # print(data.shape) # torch.Size([128, 1, 28, 28])
82 data = data.view(-1, 28*28)
83 # print(data.shape) # torch.Size([128, 784])
84
85 optimizer.zero_grad()
86 output = net(data)
87 loss = f.nll_loss(output, target)
88 loss.backward()
89 optimizer.step()
90
91 if i % 10 == 0:
92 print('Training log: {} epoch ({} / 60000 train. data). Loss: {}'.format(e+1,
93 (i+1)*128,
94 loss.item())
95 )
96
97 history['train_loss'].append(loss)
98
99 """ Test Part """
100 # 学習のストップ
101 net.eval() # または net.train(False) でも良い
102 test_loss = 0
103 correct = 0
104
105 with torch.no_grad():
106 for data, target in loaders['test']:
107 data = data.view(-1, 28 * 28)
108 output = net(data)
109 test_loss += f.nll_loss(output, target, reduction='sum').item()
110 pred = output.argmax(dim=1, keepdim=True)
111 correct += pred.eq(target.view_as(pred)).sum().item()
112
113 test_loss /= 10000
114
115 print('Test loss (avg): {}, Accuracy: {}'.format(test_loss,
116 correct / 10000))
117
118 history['test_loss'].append(test_loss)
119 history['test_acc'].append(correct / 10000)
120
121 # 結果の出力と描画
122 print(history)
123 plt.figure()
124 plt.plot(range(1, epoch+1), history['train_loss'], label='train_loss')
125 plt.plot(range(1, epoch+1), history['test_loss'], label='test_loss')
126 plt.xlabel('epoch')
127 plt.legend()
128 plt.savefig('loss.png')
129
130 plt.figure()
131 plt.plot(range(1, epoch+1), history['test_acc'])
132 plt.title('test accuracy')
133 plt.xlabel('epoch')
134 plt.savefig('test_acc.png')
動作確認
実際に動かしてみると、学習後に以下のような図が得られるはずです!
訓練ロスが若干バタついていますが、テスト精度は98%以上と、しっかり学習できていそうですね!
さいごに
今回は、PyTorchの入門編という立ち位置で「MNISTを単純なネットワークで学習」させてみました。
実際にコードを見てみると、機械学習初心者でも比較的馴染みやすい書き方だと思います。
これから機械学習を始める方、新しい機械学習ライブラリを探していた方、是非一度触ってみてください!
最初に言ったように、おそらく、これからどんどんホットになっていく機械学習ライブラリがPyTorchです!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit
おすすめ記事
浮動小数点について調べてみた
2024.09.09