• トップ
  • ブログ一覧
  • 【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる
  • 【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる

    メディアチームメディアチーム
    2020.01.23

    IT技術

    PyTorchで手書き数字(MNIST)を学習させる

    前回は、PyTorch(パイトーチ)のインストールなどを行いました。

    今回は、いよいよPyTorchで手書き数字(MNIST)データセットを学習させていきたいと思います!

    前回の記事はこちら

    featureImg2020.01.20PyTorchの特徴とインストール方法PyTorchとはPyTorch(パイトーチ)とは、Pythonの機械学習ライブラリの一つで、現在最もアツいフレームワ...

    早速実装してみる

    それでは、実装を始めます。

    入門編ということで、一つ一つ丁寧に解説していきます!

    必要なモジュールのインポート

    今回使うモジュールを先に公開しておきます。

    1import torch
    2import torch.nn.functional as f
    3from torch.utils.data import DataLoader
    4from torchvision import datasets, transforms
    5import matplotlib.pyplot as plt

    今回は、これだけ使用します。

    (つまり最低でも、前回インストールした2つのモジュールがあればOKです)

    ネットワークの構築

    まずは、今回使うネットワークを定義していきます。

    PyTorchでは、torch.nn.Module というクラスを継承して、オリジナルのネットワークを構築していきます。

    今回はMyNet という名前でネットワークを作っていきますが、ネットワーク構成はシンプルに「入力層(784) - 中間層(1000) - 出力層(10)」の3層構造とします。

    1. 中間層の活性化関数に「シグモイド(sigmoid)関数」
    2. 出力は確率にしたいので「ソフトマックス(softmax)関数」

    中間層の活性化関数に「シグモイド( sigmoid )関数」を、出力は確率にしたいので「ソフトマックス( softmax )関数」を使用します。

    今回は、MNISTという簡単なタスクで、なおかつ畳み込み層はないので、よく使用される ReLU関数 は使いません。(もちろん使っても良いです笑)

    PyTorchでは、以上のようなネットワークの場合以下のように定義していきます。

    1class MyNet(torch.nn.Module):
    2    def __init__(self):
    3        super(MyNet, self).__init__()
    4        self.fc1 = torch.nn.Linear(28*28, 1000)
    5        self.fc2 = torch.nn.Linear(1000, 10)
    6
    7    def forward(self, x):
    8        x = self.fc1(x)
    9        x = torch.sigmoid(x)
    10        x = self.fc2(x)
    11
    12        return f.log_softmax(x, dim=1)

    とてもシンプルです。

    最低限、コンストラクタ(def __init__() )と順伝播の関数(def forward() )を定義すればOKです。

    データセット(MNIST)のロード

    MNISTをロードする関数を作りましょう。

    PyTorchでは、TorchVisionというモジュールでデータセットを管理しています。

    まずは、出来上がった関数を見てみましょう。

    1def load_MNIST(batch=128, intensity=1.0):
    2    train_loader = torch.utils.data.DataLoader(
    3        datasets.MNIST('./data',
    4                       train=True,
    5                       download=True,
    6                       transform=transforms.Compose([
    7                           transforms.ToTensor(),
    8                           transforms.Lambda(lambda x: x * intensity)
    9                       ])),
    10        batch_size=batch,
    11        shuffle=True)
    12
    13    test_loader = torch.utils.data.DataLoader(
    14        datasets.MNIST('./data',
    15                       train=False,
    16                       transform=transforms.Compose([
    17                           transforms.ToTensor(),
    18                           transforms.Lambda(lambda x: x * intensity)
    19                       ])),
    20        batch_size=batch,
    21        shuffle=True)
    22
    23    return {'train': train_loader, 'test': test_loader}

    PyTorchでは、データローダーという形でデータを取り扱うことが大きな特徴の一つです。

    このデータローダーには、バッチサイズごとにまとめられたデータとラベルがまとまっています

    さらにデータは、torch.tensor というテンソルの形で扱いますが、データローダーにおけるデータの形は(batch, channel, dimension)という順番になっています。

    これは後で、実際に見てみましょう。

    また、torch.utils.data.DataLoader() では、第一引数に「データセット」を取ります。

    今回は、その第一引数にdatasets.MNIST() というMNISTのデータを扱うためのクラスインスタンスが与えられていることが分かります。

    このクラス(datasets.MNIST())では、コンストラクタとして第一引数にデータのダウンロード先を指定し、そのほかに訓練データか否か(train=True なら訓練データ、train=False ならテストデータ)を指定したり、transform= でデータを正規化したりできます。

    今回は、画素値の最大値をintensity 倍するような形ですが、他によく見る形として、

    1    train_loader = torch.utils.data.DataLoader(
    2        datasets.MNIST('./data',
    3                       train=True,
    4                       download=True,
    5                       transform=transforms.Compose([
    6                           transforms.ToTensor(),
    7                           transforms.Normalize((0.5,), (0.5,))  # ここが違う
    8                       ])),
    9        batch_size=batch,
    10        shuffle=True)

    のように、平均と分散を指定すると良い精度になる場合もあります。

    今回用意した関数では、戻り値として各ローダーを辞書型変数にして返しています。

    メイン処理部分を書く

    下準備完了です!

    早速学習させる部分を実装していきます。

    まずは、ネットワークを構築して、データを取得するまでを示します。

    1if __name__ == '__main__':
    2    # 学習回数
    3    epoch = 20
    4
    5    # 学習結果の保存用
    6    history = {
    7        'train_loss': [],
    8        'test_loss': [],
    9        'test_acc': [],
    10    }
    11
    12    # ネットワークを構築
    13    net: torch.nn.Module = MyNet()
    14    
    15    # MNISTのデータローダーを取得
    16    loaders = load_MNIST()

    これだけでOKです。

    最適化

    次は、学習率にどのような最適化を適用するかを決めます。

    今回は、Adam という最適化手法を使ってみましょう。

    1optimizer = torch.optim.Adam(params=net.parameters(), lr=0.001)

    初期学習率は、0.001 としました。

    学習部分の実装

    では、核となる学習部分の実装に移ります。

    大枠としては、下記のように学習回数のループの中に、訓練データのループとテスト(検証)データのループを作ります。

    1    for e in range(epoch):
    2
    3        """ Training Part"""
    4        loss = None
    5
    6        # 学習開始 (再開)
    7        net.train(True)  # 引数は省略可能
    8
    9        for i, (data, target) in enumerate(loaders['train']):
    10            pass
    11           ########## 学習部分 ##########
    12
    13        """ Test Part """
    14        # 学習のストップ
    15        net.eval()  # または net.train(False) でも良い
    16
    17        with torch.no_grad():  # テスト部分では勾配は使わないのでこのように書く
    18            for data, target in loaders['test']:
    19                pass
    20                ########## テスト部分 ##########

    それでは、早速中身を書いていきましょう。

    訓練部分の実装

    訓練部分は、以下のようにコーディングしてみました。

    1""" Training Part"""
    2        loss = None
    3        # 学習開始 (再開)
    4        net.train(True)  # 引数は省略可能
    5        for i, (data, target) in enumerate(loaders['train']):
    6            # 全結合のみのネットワークでは入力を1次元に
    7            # print(data.shape)  # torch.Size([128, 1, 28, 28])
    8            data = data.view(-1, 28*28)
    9            # print(data.shape)  # torch.Size([128, 784])
    10
    11            optimizer.zero_grad()
    12            output = net(data)
    13            loss = f.nll_loss(output, target)
    14            loss.backward()
    15            optimizer.step()
    16
    17            if i % 10 == 0:
    18                print('Training log: {} epoch ({} / 60000 train. data). Loss: {}'.format(e+1,
    19                                                                                         (i+1)*128,
    20                                                                                         loss.item())
    21                      )
    22
    23        history['train_loss'].append(loss)

    ここで実際にデータの形を出力してみると、先ほど話をしたように(batch, channel, dimension)になっていることがわかります。

    ネットワークにデータを入力して出力を得るまでは、output = net(data) だけで済むのは簡単ですね!

    今回入力は、1次元でグレースケールなので、data = data.view(-1, 28*28) で形を調整します。

    そのあとは、ロスを計算(loss = f.nll_loss(output, target) )して、そのロスを元に誤差を逆伝播(loss.backward() )しているだけです。

    ログは、10batch 毎に出力するようにしてみました。

    テスト部分の作成

    これで訓練部分は完成したので、テスト部分を作ります。

    1""" Test Part """
    2        # 学習のストップ
    3        net.eval()  # または net.train(False) でも良い
    4        test_loss = 0
    5        correct = 0
    6
    7        with torch.no_grad():
    8            for data, target in loaders['test']:
    9                data = data.view(-1, 28 * 28)
    10                output = net(data)
    11                test_loss += f.nll_loss(output, target, reduction='sum').item()
    12                pred = output.argmax(dim=1, keepdim=True)
    13                correct += pred.eq(target.view_as(pred)).sum().item()
    14
    15        test_loss /= 10000
    16
    17        print('Test loss (avg): {}, Accuracy: {}'.format(test_loss,
    18                                                         correct / 10000))
    19
    20        history['test_loss'].append(test_loss)
    21        history['test_acc'].append(correct / 10000)

    これも先ほどのコードと似ている部分がたくさんありますね。

    テスト部分のロスは全て足して、最後に平均を取ることで、その学習(epoch)でのロスとしています。

    また、テスト部分では精度も測りたいので、softmaxの確率出力の中で一番大きいニューロンのインデックスを取得しています (pred = output.argmax(dim=1, keepdim=True) )。

    このあと、ラベルと比較して一致しているものを正解数として記録しています。

    完成!最終的なコード

    最後に、結果を描画する部分を加筆して完成です!

    ちなみに今回書いたコードは、「これが正解・最適」というわけではなく、筆者の好みも現れていますので、適宜自分の理解しやすいようにコーディングしてください!

    1"""
    2PyTorchでMNISTを学習させる
    3
    4:summary   PyTorchで単純な多層パーセプトロンを構築してみる
    5:author    RightCode Inc. (https://rightcode.co.jp)
    6"""
    7
    8import torch
    9import torch.nn.functional as f
    10from torch.utils.data import DataLoader
    11from torchvision import datasets, transforms
    12import matplotlib.pyplot as plt
    13
    14
    15class MyNet(torch.nn.Module):
    16    def __init__(self):
    17        super(MyNet, self).__init__()
    18        self.fc1 = torch.nn.Linear(28*28, 1000)
    19        self.fc2 = torch.nn.Linear(1000, 10)
    20
    21    def forward(self, x):
    22        x = self.fc1(x)
    23        x = torch.sigmoid(x)
    24        x = self.fc2(x)
    25
    26        return f.log_softmax(x, dim=1)
    27
    28
    29def load_MNIST(batch=128, intensity=1.0):
    30    train_loader = torch.utils.data.DataLoader(
    31        datasets.MNIST('./data',
    32                       train=True,
    33                       download=True,
    34                       transform=transforms.Compose([
    35                           transforms.ToTensor(),
    36                           transforms.Lambda(lambda x: x * intensity)
    37                       ])),
    38        batch_size=batch,
    39        shuffle=True)
    40
    41    test_loader = torch.utils.data.DataLoader(
    42        datasets.MNIST('./data',
    43                       train=False,
    44                       transform=transforms.Compose([
    45                           transforms.ToTensor(),
    46                           transforms.Lambda(lambda x: x * intensity)
    47                       ])),
    48        batch_size=batch,
    49        shuffle=True)
    50
    51    return {'train': train_loader, 'test': test_loader}
    52
    53
    54if __name__ == '__main__':
    55    # 学習回数
    56    epoch = 20
    57
    58    # 学習結果の保存用
    59    history = {
    60        'train_loss': [],
    61        'test_loss': [],
    62        'test_acc': [],
    63    }
    64
    65    # ネットワークを構築
    66    net: torch.nn.Module = MyNet()
    67
    68    # MNISTのデータローダーを取得
    69    loaders = load_MNIST()
    70
    71    optimizer = torch.optim.Adam(params=net.parameters(), lr=0.001)
    72
    73    for e in range(epoch):
    74
    75        """ Training Part"""
    76        loss = None
    77        # 学習開始 (再開)
    78        net.train(True)  # 引数は省略可能
    79        for i, (data, target) in enumerate(loaders['train']):
    80            # 全結合のみのネットワークでは入力を1次元に
    81            # print(data.shape)  # torch.Size([128, 1, 28, 28])
    82            data = data.view(-1, 28*28)
    83            # print(data.shape)  # torch.Size([128, 784])
    84
    85            optimizer.zero_grad()
    86            output = net(data)
    87            loss = f.nll_loss(output, target)
    88            loss.backward()
    89            optimizer.step()
    90
    91            if i % 10 == 0:
    92                print('Training log: {} epoch ({} / 60000 train. data). Loss: {}'.format(e+1,
    93                                                                                         (i+1)*128,
    94                                                                                         loss.item())
    95                      )
    96
    97        history['train_loss'].append(loss)
    98
    99        """ Test Part """
    100        # 学習のストップ
    101        net.eval()  # または net.train(False) でも良い
    102        test_loss = 0
    103        correct = 0
    104
    105        with torch.no_grad():
    106            for data, target in loaders['test']:
    107                data = data.view(-1, 28 * 28)
    108                output = net(data)
    109                test_loss += f.nll_loss(output, target, reduction='sum').item()
    110                pred = output.argmax(dim=1, keepdim=True)
    111                correct += pred.eq(target.view_as(pred)).sum().item()
    112
    113        test_loss /= 10000
    114
    115        print('Test loss (avg): {}, Accuracy: {}'.format(test_loss,
    116                                                         correct / 10000))
    117
    118        history['test_loss'].append(test_loss)
    119        history['test_acc'].append(correct / 10000)
    120
    121    # 結果の出力と描画
    122    print(history)
    123    plt.figure()
    124    plt.plot(range(1, epoch+1), history['train_loss'], label='train_loss')
    125    plt.plot(range(1, epoch+1), history['test_loss'], label='test_loss')
    126    plt.xlabel('epoch')
    127    plt.legend()
    128    plt.savefig('loss.png')
    129
    130    plt.figure()
    131    plt.plot(range(1, epoch+1), history['test_acc'])
    132    plt.title('test accuracy')
    133    plt.xlabel('epoch')
    134    plt.savefig('test_acc.png')

    動作確認

    実際に動かしてみると、学習後に以下のような図が得られるはずです!

    PyTorchでMNIST_Loss

    PyTorchでMNIST_Accuracy

    訓練ロスが若干バタついていますが、テスト精度は98%以上と、しっかり学習できていそうですね!

    さいごに

    今回は、PyTorchの入門編という立ち位置で「MNISTを単純なネットワークで学習」させてみました。

    実際にコードを見てみると、機械学習初心者でも比較的馴染みやすい書き方だと思います。

    これから機械学習を始める方、新しい機械学習ライブラリを探していた方、是非一度触ってみてください!

    最初に言ったように、おそらく、これからどんどんホットになっていく機械学習ライブラリがPyTorchです!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    ライトコードでは、エンジニアを積極採用中!

    ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。

    採用情報へ

    メディアチーム
    メディアチーム
    Show more...

    おすすめ記事

    エンジニア大募集中!

    ライトコードでは、エンジニアを積極採用中です。

    特に、WEBエンジニアとモバイルエンジニアは是非ご応募お待ちしております!

    また、フリーランスエンジニア様も大募集中です。

    background