• トップ
  • ブログ一覧
  • 【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
  • 【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

    広告メディア事業部広告メディア事業部
    2020.02.13

    IT技術

    【後編】PyTorchでCIFAR-10をCNNに学習させる

    【前編】の続きとなります。

    引き続き、PyTorch(パイトーチ)で畳み込みニューラルネットワーク(CNN)を実装していきたいと思います。

    今回は、学習結果からとなります!

    前編の記事はこちら

    featureImg2020.02.07【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】PyTorchでCIFAR-10をCNNに学習させる前回の『【PyTorch入門】PyTorchで手書き数字(MNIS...

    学習結果

    学習が終わりましたが、やはり「MNIST」と違って学習に時間がかかりますね!

    ですが、50エポックなので数十分で終わるかと思います。(マシンスペックに依存しますが...)

    訓練ロスと訓練 / テスト精度

    学習によって得られた、『訓練ロス』『訓練 / テスト精度』から見てみましょう。

    学習結果 (訓練Loss)
    学習結果 (精度)

    まだ、精度は「70%程度」と低いですが、しっかり学習できていそうですね!

    畳み込み層のフィルタ

    ちなみに、学習後の畳み込み層のフィルタを見てみると

    学習後 のconv.1の重み (50 epoch)
    学習後のconv.2の重み (50 epoch)

    学習前と比べて、何かしらフィルタに模様が見えてきましたね。

    よく見ると、斜め方向に対応するフィルタや、横方向に対応するフィルタが見受けられますが、まだはっきりとは分かりませんね。

    またグラフを見ると、学習回数を増やせば、まだ精度は伸びそうな雰囲気があります。

    「もっともっと学習を増やしてみましょう!」

    …と言いたいところですが、そうなると学習に膨大な時間がかかってしまいそうです。

    GPUを使ってみる

    CUDA(Compute Unified Device Architecture:クーダ)」が使用可能であれば、PyTorchでは、簡単にGPUに演算を行わせることができます。

    メイン処理部の冒頭に、以下を加筆して下さい。

    1if __name__ == '__main__':
    2    epoch = 300  # 今回は300エポック!
    3
    4    loader = load_cifar10()
    5    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    6               'dog', 'frog', 'horse', 'ship', 'truck')
    7
    8    net: MyCNN = MyCNN()
    9    criterion = torch.nn.CrossEntropyLoss()  # ロスの計算
    10    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    11
    12    """ 以下を加筆 """
    13    # もしGPUが使えるなら使う
    14    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    15    net.to(device)
    16    print(device)

    torch.cuda.is_available() は、「CUDA」が使用可能ならTrueを返すといった、GPU を使用できるかを簡単に確認できる関数です。

    その後、ネットワークを to() でデバイスに投げるだけです。

    しかし、これだけではエラーを吐かれてしまいます。

    GPU に使用するデータセットを投げる

    使用するデータセットも、GPU に投げる必要があるので、各データローダーのループに、以下のように加筆して下さい。

    1# ...
    2for i, (images, labels) in enumerate(loader['train']):
    3    images = images.to(device)  # to GPU?
    4    labels = labels.to(device)
    5    # ...

    matplotlib で重みを描画するには

    最後に、重みを描画する「matplotlib」では、CPU で描画するため、その際は逆に CPU に投げる必要があります

    以下のように書けば OK です!

    1# ...
    2plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')  # to('cpu')を追加!
    3# ...

    これで、GPU を使用する準備が整いました!

    早速学習させてみましょう!

    ちなみに筆者の環境は、GPU は「NVIDIA GeForce GTX Titan Blak 6GB」で、「CUDA」はバーション9.2で動作確認をしています。

    学習結果 (300エポック)

    さすがGPUを使うと、目に見えて学習が早いです!

    時間はしっかりと測定していませんが、1.5 ~ 2倍くらい早いです。

    では早速、300エポックの学習結果を見てみましょう!

    300エポックの学習結果...?

    なんと、訓練精度は「約100%」になりましたが、テスト精度は「60%程度」になってしまいました

    「過学習」が起こってしまいました!

    【過学習とは?】
    過学習とは、このように訓練データに適応しすぎて、テストデータなどに対する性能、いわゆる汎化性能が低下してしまう現象を言います。

    Dropoutを導入してみる

    それでは、過学習対策としてメジャーな「Dropout」を導入してみましょう。

    「Dropout」とは、簡単に言えば、学習時に一部のニューロンを、わざと非活性化させ、訓練データに適合しすぎないようにする手法です。

    1class MyCNN(torch.nn.Module):
    2    def __init__(self):
    3        super(MyCNN, self).__init__()
    4        self.conv1 = torch.nn.Conv2d(3,  # チャネル入力
    5                                     6,  # チャンネル出力
    6                                     5,  # カーネルサイズ
    7                                     1,  # ストライド (デフォルトは1)
    8                                     0,  # パディング (デフォルトは0)
    9                                     )
    10        self.conv2 = torch.nn.Conv2d(6, 16, 5)
    11
    12        self.pool = torch.nn.MaxPool2d(2, 2)  # カーネルサイズ, ストライド
    13
    14        self.dropout1 = torch.nn.Dropout2d(p=0.3)  # [new] Dropoutを追加してみる
    15
    16        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)  # 入力サイズ, 出力サイズ
    17        self.dropout2 = torch.nn.Dropout(p=0.5)  # [new] Dropoutを追加してみる
    18        self.fc2 = torch.nn.Linear(120, 84)
    19        self.fc3 = torch.nn.Linear(84, 10)
    20
    21    def forward(self, x):
    22        x = f.relu(self.conv1(x))
    23        x = self.pool(x)
    24        x = f.relu(self.conv2(x))
    25        x = self.pool(x)
    26        x = self.dropout1(x)  # [new] Dropoutを追加
    27        x = x.view(-1, 16 * 5 * 5)  # 1次元データに変えて全結合層へ
    28        x = f.relu(self.fc1(x))
    29        x = self.dropout2(x)   # [new] Dropoutを追加
    30        x = f.relu(self.fc2(x))
    31        x = self.fc3(x)
    32
    33        return x

    追加してみました!

    さて、結果は、どうなるでしょうか!?

    実験結果 (300エポック + Dropout)

    学習結果

    過学習はなくなりましたね!

    ただ、やはり精度は「70%程度」といったところでしょうか。

    もしかしたら、これが「LeNet」の限界なのかもしれません...。

    他にも「Batch Normalization(バッチ正規化)」や、「Data Augmentation(データ拡張)」などの手法を用いれば、過学習を抑制しつつ精度向上が見込めるかもしれません。

    ですが、本記事ではここまでとします。

    ここまできたら、もう少し深い層の、畳み込みニューラルネットワーク (DCNN: Deep Convolutional Neural Networks)を構築した方が良いでしょう。

    フィルタの重みの学習結果

    ちなみに、フィルタの重みの学習結果も載せておきますが、50エポックの時と、大した差はありませんね。

    やや、模様が明確になったような気もします(笑)

    Conv.1のフィルタ
    Conv.2のフィルタ

    さいごに

    長丁場になりましたが、今回は、PyTorchで畳み込みニューラルネットワークを構築し、カラー画像の「CIFAR10」を学習させてみました。

    また、ネットワークの内部を観察したり、いろいろな考察もしてみたので、機械学習初学者の皆さまの参考になれば幸いです。

    次回は、もう少し複雑なネットワークで試してみようと考えているのでお楽しみに!

    ソースコード

    1import torch
    2import torch.nn.functional as f
    3from torch.utils.data import DataLoader
    4from torchvision import datasets, transforms
    5import matplotlib.pyplot as plt
    6from tqdm import tqdm
    7
    8
    9class MyCNN(torch.nn.Module):
    10    def __init__(self):
    11        super(MyCNN, self).__init__()
    12        self.conv1 = torch.nn.Conv2d(3,  # チャネル入力
    13                                     6,  # チャンネル出力
    14                                     5,  # カーネルサイズ
    15                                     1,  # ストライド (デフォルトは1)
    16                                     0,  # パディング (デフォルトは0)
    17                                     )
    18        self.conv2 = torch.nn.Conv2d(6, 16, 5)
    19
    20        self.pool = torch.nn.MaxPool2d(2, 2)  # カーネルサイズ, ストライド
    21
    22        self.dropout1 = torch.nn.Dropout2d(p=0.3)  # [new] Dropoutを追加してみる
    23
    24        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)  # 入力サイズ, 出力サイズ
    25        self.dropout2 = torch.nn.Dropout(p=0.5)  # [new] Dropoutを追加してみる
    26        self.fc2 = torch.nn.Linear(120, 84)
    27        self.fc3 = torch.nn.Linear(84, 10)
    28
    29    def forward(self, x):
    30        x = f.relu(self.conv1(x))
    31        x = self.pool(x)
    32        x = f.relu(self.conv2(x))
    33        x = self.pool(x)
    34        x = self.dropout1(x)  # [new] Dropoutを追加
    35        x = x.view(-1, 16 * 5 * 5)  # 1次元データに変えて全結合層へ
    36        x = f.relu(self.fc1(x))
    37        x = self.dropout2(x)   # [new] Dropoutを追加
    38        x = f.relu(self.fc2(x))
    39        x = self.fc3(x)
    40
    41        return x
    42
    43    def plot_conv1(self, prefix_num=0):
    44        weights1 = self.conv1.weight
    45        weights1 = weights1.reshape(3*6, 5, 5)
    46
    47        for i, weight in enumerate(weights1):
    48            plt.subplot(3, 6, i + 1)
    49            plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')
    50            plt.tick_params(labelbottom=False,
    51                            labelleft=False,
    52                            labelright=False,
    53                            labeltop=False,
    54                            bottom=False,
    55                            left=False,
    56                            right=False,
    57                            top=False)
    58
    59        plt.savefig('img/{}_conv1.png'.format(prefix_num))
    60        plt.close()
    61
    62    def plot_conv2(self, prefix_num=0):
    63        weights2 = self.conv2.weight
    64        weights2 = weights2.reshape(6*16, 5, 5)
    65
    66        for i, weight in enumerate(weights2):
    67            plt.subplot(6, 16, i + 1)
    68            plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')
    69            plt.tick_params(labelbottom=False,
    70                            labelleft=False,
    71                            labelright=False,
    72                            labeltop=False,
    73                            bottom=False,
    74                            left=False,
    75                            right=False,
    76                            top=False)
    77
    78        plt.savefig('img/{}_conv2.png'.format(prefix_num))
    79        plt.close()
    80
    81
    82def load_cifar10(batch=128):
    83    train_loader = DataLoader(
    84        datasets.CIFAR10('./data',
    85                         train=True,
    86                         download=True,
    87                         transform=transforms.Compose([
    88                             transforms.ToTensor(),
    89                             transforms.Normalize(
    90                                [0.5, 0.5, 0.5],  # RGB 平均
    91                                [0.5, 0.5, 0.5]   # RGB 標準偏差
    92                                )
    93                         ])),
    94        batch_size=batch,
    95        shuffle=True
    96    )
    97
    98    test_loader = DataLoader(
    99        datasets.CIFAR10('./data',
    100                         train=False,
    101                         download=True,
    102                         transform=transforms.Compose([
    103                             transforms.ToTensor(),
    104                             transforms.Normalize(
    105                                 [0.5, 0.5, 0.5],  # RGB 平均
    106                                 [0.5, 0.5, 0.5]  # RGB 標準偏差
    107                             )
    108                         ])),
    109        batch_size=batch,
    110        shuffle=True
    111    )
    112
    113    return {'train': train_loader, 'test': test_loader}
    114
    115
    116if __name__ == '__main__':
    117    epoch = 300
    118
    119    loader = load_cifar10()
    120    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    121               'dog', 'frog', 'horse', 'ship', 'truck')
    122
    123    net: MyCNN = MyCNN()
    124    criterion = torch.nn.CrossEntropyLoss()  # ロスの計算
    125    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    126
    127    # もしGPUが使えるなら使う
    128    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    129    net.to(device)
    130    print(device)
    131
    132    # 学習前のフィルタの可視化
    133    net.plot_conv1()
    134    net.plot_conv2()
    135
    136    history = {
    137        'train_loss': [],
    138        'train_acc': [],
    139        'test_acc': []
    140    }
    141
    142    for e in range(epoch):
    143        net.train()
    144        loss = None
    145        for i, (images, labels) in enumerate(loader['train']):
    146            images = images.to(device)  # to GPU?
    147            labels = labels.to(device)
    148
    149            optimizer.zero_grad()
    150            output = net(images)
    151            loss = criterion(output, labels)
    152            loss.backward()
    153            optimizer.step()
    154
    155            if i % 10 == 0:
    156                print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1,
    157                                                                                         (i + 1) * 128,
    158                                                                                         loss.item())
    159                      )
    160
    161        # 学習過程でのフィルタの可視化
    162        # net.plot_conv1(e+1)
    163        # net.plot_conv2(e+1)
    164
    165        history['train_loss'].append(loss.item())
    166
    167        net.eval()
    168        correct = 0
    169        with torch.no_grad():
    170            for i, (images, labels) in enumerate(tqdm(loader['train'])):
    171                images = images.to(device)  # to GPU?
    172                labels = labels.to(device)
    173
    174                outputs = net(images)
    175                _, predicted = torch.max(outputs.data, 1)
    176                correct += (predicted == labels).sum().item()
    177
    178        acc = float(correct / 50000)
    179        history['train_acc'].append(acc)
    180
    181        correct = 0
    182        with torch.no_grad():
    183            for i, (images, labels) in enumerate(tqdm(loader['test'])):
    184                images = images.to(device)  # to GPU?
    185                labels = labels.to(device)
    186
    187                outputs = net(images)
    188                _, predicted = torch.max(outputs.data, 1)
    189                correct += (predicted == labels).sum().item()
    190
    191        acc = float(correct / 10000)
    192        history['test_acc'].append(acc)
    193
    194    # 学習前のフィルタの可視化
    195    net.plot_conv1(300)
    196    net.plot_conv2(300)
    197
    198    # 結果をプロット
    199    plt.plot(range(1, epoch+1), history['train_loss'])
    200    plt.title('Training Loss [CIFAR10]')
    201    plt.xlabel('epoch')
    202    plt.ylabel('loss')
    203    plt.savefig('img/cifar10_loss.png')
    204    plt.close()
    205
    206    plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc')
    207    plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc')
    208    plt.title('Accuracies [CIFAR10]')
    209    plt.xlabel('epoch')
    210    plt.ylabel('accuracy')
    211    plt.legend()
    212    plt.savefig('img/cifar10_acc.png')
    213    plt.close()

    前編の記事はこちら

    featureImg2020.02.07【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】PyTorchでCIFAR-10をCNNに学習させる前回の『【PyTorch入門】PyTorchで手書き数字(MNIS...

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    関連記事

    featureImg2020.01.23【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させるPyTorchで手書き数字(MNIST)を学習させる前回は、PyTorch(パイトーチ)のインストールなどを行いました...

    広告メディア事業部

    広告メディア事業部

    おすすめ記事