• トップ
  • ブログ一覧
  • 【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
  • 【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

    広告メディア事業部広告メディア事業部
    2020.02.13

    IT技術

    【後編】PyTorchでCIFAR-10をCNNに学習させる

    【前編】の続きとなります。

    引き続き、PyTorch(パイトーチ)で畳み込みニューラルネットワーク(CNN)を実装していきたいと思います。

    今回は、学習結果からとなります!

    前編の記事はこちら

    featureImg2020.02.07【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】PyTorchでCIFAR-10をCNNに学習させる前回の『【PyTorch入門】PyTorchで手書き数字(MNIS...

    学習結果

    学習が終わりましたが、やはり「MNIST」と違って学習に時間がかかりますね!

    ですが、50エポックなので数十分で終わるかと思います。(マシンスペックに依存しますが...)

    訓練ロスと訓練 / テスト精度

    学習によって得られた、『訓練ロス』『訓練 / テスト精度』から見てみましょう。

    学習結果 (訓練Loss)
    学習結果 (精度)

    まだ、精度は「70%程度」と低いですが、しっかり学習できていそうですね!

    畳み込み層のフィルタ

    ちなみに、学習後の畳み込み層のフィルタを見てみると

    学習後 のconv.1の重み (50 epoch)
    学習後のconv.2の重み (50 epoch)

    学習前と比べて、何かしらフィルタに模様が見えてきましたね。

    よく見ると、斜め方向に対応するフィルタや、横方向に対応するフィルタが見受けられますが、まだはっきりとは分かりませんね。

    またグラフを見ると、学習回数を増やせば、まだ精度は伸びそうな雰囲気があります。

    「もっともっと学習を増やしてみましょう!」

    …と言いたいところですが、そうなると学習に膨大な時間がかかってしまいそうです。

    GPUを使ってみる

    CUDA(Compute Unified Device Architecture:クーダ)」が使用可能であれば、PyTorchでは、簡単にGPUに演算を行わせることができます。

    メイン処理部の冒頭に、以下を加筆して下さい。

    1if __name__ == '__main__':
    2    epoch = 300  # 今回は300エポック!
    3
    4    loader = load_cifar10()
    5    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    6               'dog', 'frog', 'horse', 'ship', 'truck')
    7
    8    net: MyCNN = MyCNN()
    9    criterion = torch.nn.CrossEntropyLoss()  # ロスの計算
    10    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    11
    12    """ 以下を加筆 """
    13    # もしGPUが使えるなら使う
    14    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    15    net.to(device)
    16    print(device)

    torch.cuda.is_available() は、「CUDA」が使用可能ならTrueを返すといった、GPU を使用できるかを簡単に確認できる関数です。

    その後、ネットワークを to() でデバイスに投げるだけです。

    しかし、これだけではエラーを吐かれてしまいます。

    GPU に使用するデータセットを投げる

    使用するデータセットも、GPU に投げる必要があるので、各データローダーのループに、以下のように加筆して下さい。

    1# ...
    2for i, (images, labels) in enumerate(loader['train']):
    3    images = images.to(device)  # to GPU?
    4    labels = labels.to(device)
    5    # ...

    matplotlib で重みを描画するには

    最後に、重みを描画する「matplotlib」では、CPU で描画するため、その際は逆に CPU に投げる必要があります

    以下のように書けば OK です!

    1# ...
    2plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')  # to('cpu')を追加!
    3# ...

    これで、GPU を使用する準備が整いました!

    早速学習させてみましょう!

    ちなみに筆者の環境は、GPU は「NVIDIA GeForce GTX Titan Blak 6GB」で、「CUDA」はバーション9.2で動作確認をしています。

    学習結果 (300エポック)

    さすがGPUを使うと、目に見えて学習が早いです!

    時間はしっかりと測定していませんが、1.5 ~ 2倍くらい早いです。

    では早速、300エポックの学習結果を見てみましょう!

    300エポックの学習結果...?

    なんと、訓練精度は「約100%」になりましたが、テスト精度は「60%程度」になってしまいました

    「過学習」が起こってしまいました!

    【過学習とは?】
    過学習とは、このように訓練データに適応しすぎて、テストデータなどに対する性能、いわゆる汎化性能が低下してしまう現象を言います。

    Dropoutを導入してみる

    それでは、過学習対策としてメジャーな「Dropout」を導入してみましょう。

    「Dropout」とは、簡単に言えば、学習時に一部のニューロンを、わざと非活性化させ、訓練データに適合しすぎないようにする手法です。

    1class MyCNN(torch.nn.Module):
    2    def __init__(self):
    3        super(MyCNN, self).__init__()
    4        self.conv1 = torch.nn.Conv2d(3,  # チャネル入力
    5                                     6,  # チャンネル出力
    6                                     5,  # カーネルサイズ
    7                                     1,  # ストライド (デフォルトは1)
    8                                     0,  # パディング (デフォルトは0)
    9                                     )
    10        self.conv2 = torch.nn.Conv2d(6, 16, 5)
    11
    12        self.pool = torch.nn.MaxPool2d(2, 2)  # カーネルサイズ, ストライド
    13
    14        self.dropout1 = torch.nn.Dropout2d(p=0.3)  # [new] Dropoutを追加してみる
    15
    16        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)  # 入力サイズ, 出力サイズ
    17        self.dropout2 = torch.nn.Dropout(p=0.5)  # [new] Dropoutを追加してみる
    18        self.fc2 = torch.nn.Linear(120, 84)
    19        self.fc3 = torch.nn.Linear(84, 10)
    20
    21    def forward(self, x):
    22        x = f.relu(self.conv1(x))
    23        x = self.pool(x)
    24        x = f.relu(self.conv2(x))
    25        x = self.pool(x)
    26        x = self.dropout1(x)  # [new] Dropoutを追加
    27        x = x.view(-1, 16 * 5 * 5)  # 1次元データに変えて全結合層へ
    28        x = f.relu(self.fc1(x))
    29        x = self.dropout2(x)   # [new] Dropoutを追加
    30        x = f.relu(self.fc2(x))
    31        x = self.fc3(x)
    32
    33        return x

    追加してみました!

    さて、結果は、どうなるでしょうか!?

    実験結果 (300エポック + Dropout)

    学習結果

    過学習はなくなりましたね!

    ただ、やはり精度は「70%程度」といったところでしょうか。

    もしかしたら、これが「LeNet」の限界なのかもしれません...。

    他にも「Batch Normalization(バッチ正規化)」や、「Data Augmentation(データ拡張)」などの手法を用いれば、過学習を抑制しつつ精度向上が見込めるかもしれません。

    ですが、本記事ではここまでとします。

    ここまできたら、もう少し深い層の、畳み込みニューラルネットワーク (DCNN: Deep Convolutional Neural Networks)を構築した方が良いでしょう。

    フィルタの重みの学習結果

    ちなみに、フィルタの重みの学習結果も載せておきますが、50エポックの時と、大した差はありませんね。

    やや、模様が明確になったような気もします(笑)

    Conv.1のフィルタ
    Conv.2のフィルタ

    さいごに

    長丁場になりましたが、今回は、PyTorchで畳み込みニューラルネットワークを構築し、カラー画像の「CIFAR10」を学習させてみました。

    また、ネットワークの内部を観察したり、いろいろな考察もしてみたので、機械学習初学者の皆さまの参考になれば幸いです。

    次回は、もう少し複雑なネットワークで試してみようと考えているのでお楽しみに!

    ソースコード

    1import torch
    2import torch.nn.functional as f
    3from torch.utils.data import DataLoader
    4from torchvision import datasets, transforms
    5import matplotlib.pyplot as plt
    6from tqdm import tqdm
    7
    8
    9class MyCNN(torch.nn.Module):
    10    def __init__(self):
    11        super(MyCNN, self).__init__()
    12        self.conv1 = torch.nn.Conv2d(3,  # チャネル入力
    13                                     6,  # チャンネル出力
    14                                     5,  # カーネルサイズ
    15                                     1,  # ストライド (デフォルトは1)
    16                                     0,  # パディング (デフォルトは0)
    17                                     )
    18        self.conv2 = torch.nn.Conv2d(6, 16, 5)
    19
    20        self.pool = torch.nn.MaxPool2d(2, 2)  # カーネルサイズ, ストライド
    21
    22        self.dropout1 = torch.nn.Dropout2d(p=0.3)  # [new] Dropoutを追加してみる
    23
    24        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)  # 入力サイズ, 出力サイズ
    25        self.dropout2 = torch.nn.Dropout(p=0.5)  # [new] Dropoutを追加してみる
    26        self.fc2 = torch.nn.Linear(120, 84)
    27        self.fc3 = torch.nn.Linear(84, 10)
    28
    29    def forward(self, x):
    30        x = f.relu(self.conv1(x))
    31        x = self.pool(x)
    32        x = f.relu(self.conv2(x))
    33        x = self.pool(x)
    34        x = self.dropout1(x)  # [new] Dropoutを追加
    35        x = x.view(-1, 16 * 5 * 5)  # 1次元データに変えて全結合層へ
    36        x = f.relu(self.fc1(x))
    37        x = self.dropout2(x)   # [new] Dropoutを追加
    38        x = f.relu(self.fc2(x))
    39        x = self.fc3(x)
    40
    41        return x
    42
    43    def plot_conv1(self, prefix_num=0):
    44        weights1 = self.conv1.weight
    45        weights1 = weights1.reshape(3*6, 5, 5)
    46
    47        for i, weight in enumerate(weights1):
    48            plt.subplot(3, 6, i + 1)
    49            plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')
    50            plt.tick_params(labelbottom=False,
    51                            labelleft=False,
    52                            labelright=False,
    53                            labeltop=False,
    54                            bottom=False,
    55                            left=False,
    56                            right=False,
    57                            top=False)
    58
    59        plt.savefig('img/{}_conv1.png'.format(prefix_num))
    60        plt.close()
    61
    62    def plot_conv2(self, prefix_num=0):
    63        weights2 = self.conv2.weight
    64        weights2 = weights2.reshape(6*16, 5, 5)
    65
    66        for i, weight in enumerate(weights2):
    67            plt.subplot(6, 16, i + 1)
    68            plt.imshow(weight.data.to('cpu').numpy(), cmap='winter')
    69            plt.tick_params(labelbottom=False,
    70                            labelleft=False,
    71                            labelright=False,
    72                            labeltop=False,
    73                            bottom=False,
    74                            left=False,
    75                            right=False,
    76                            top=False)
    77
    78        plt.savefig('img/{}_conv2.png'.format(prefix_num))
    79        plt.close()
    80
    81
    82def load_cifar10(batch=128):
    83    train_loader = DataLoader(
    84        datasets.CIFAR10('./data',
    85                         train=True,
    86                         download=True,
    87                         transform=transforms.Compose([
    88                             transforms.ToTensor(),
    89                             transforms.Normalize(
    90                                [0.5, 0.5, 0.5],  # RGB 平均
    91                                [0.5, 0.5, 0.5]   # RGB 標準偏差
    92                                )
    93                         ])),
    94        batch_size=batch,
    95        shuffle=True
    96    )
    97
    98    test_loader = DataLoader(
    99        datasets.CIFAR10('./data',
    100                         train=False,
    101                         download=True,
    102                         transform=transforms.Compose([
    103                             transforms.ToTensor(),
    104                             transforms.Normalize(
    105                                 [0.5, 0.5, 0.5],  # RGB 平均
    106                                 [0.5, 0.5, 0.5]  # RGB 標準偏差
    107                             )
    108                         ])),
    109        batch_size=batch,
    110        shuffle=True
    111    )
    112
    113    return {'train': train_loader, 'test': test_loader}
    114
    115
    116if __name__ == '__main__':
    117    epoch = 300
    118
    119    loader = load_cifar10()
    120    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    121               'dog', 'frog', 'horse', 'ship', 'truck')
    122
    123    net: MyCNN = MyCNN()
    124    criterion = torch.nn.CrossEntropyLoss()  # ロスの計算
    125    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    126
    127    # もしGPUが使えるなら使う
    128    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    129    net.to(device)
    130    print(device)
    131
    132    # 学習前のフィルタの可視化
    133    net.plot_conv1()
    134    net.plot_conv2()
    135
    136    history = {
    137        'train_loss': [],
    138        'train_acc': [],
    139        'test_acc': []
    140    }
    141
    142    for e in range(epoch):
    143        net.train()
    144        loss = None
    145        for i, (images, labels) in enumerate(loader['train']):
    146            images = images.to(device)  # to GPU?
    147            labels = labels.to(device)
    148
    149            optimizer.zero_grad()
    150            output = net(images)
    151            loss = criterion(output, labels)
    152            loss.backward()
    153            optimizer.step()
    154
    155            if i % 10 == 0:
    156                print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1,
    157                                                                                         (i + 1) * 128,
    158                                                                                         loss.item())
    159                      )
    160
    161        # 学習過程でのフィルタの可視化
    162        # net.plot_conv1(e+1)
    163        # net.plot_conv2(e+1)
    164
    165        history['train_loss'].append(loss.item())
    166
    167        net.eval()
    168        correct = 0
    169        with torch.no_grad():
    170            for i, (images, labels) in enumerate(tqdm(loader['train'])):
    171                images = images.to(device)  # to GPU?
    172                labels = labels.to(device)
    173
    174                outputs = net(images)
    175                _, predicted = torch.max(outputs.data, 1)
    176                correct += (predicted == labels).sum().item()
    177
    178        acc = float(correct / 50000)
    179        history['train_acc'].append(acc)
    180
    181        correct = 0
    182        with torch.no_grad():
    183            for i, (images, labels) in enumerate(tqdm(loader['test'])):
    184                images = images.to(device)  # to GPU?
    185                labels = labels.to(device)
    186
    187                outputs = net(images)
    188                _, predicted = torch.max(outputs.data, 1)
    189                correct += (predicted == labels).sum().item()
    190
    191        acc = float(correct / 10000)
    192        history['test_acc'].append(acc)
    193
    194    # 学習前のフィルタの可視化
    195    net.plot_conv1(300)
    196    net.plot_conv2(300)
    197
    198    # 結果をプロット
    199    plt.plot(range(1, epoch+1), history['train_loss'])
    200    plt.title('Training Loss [CIFAR10]')
    201    plt.xlabel('epoch')
    202    plt.ylabel('loss')
    203    plt.savefig('img/cifar10_loss.png')
    204    plt.close()
    205
    206    plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc')
    207    plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc')
    208    plt.title('Accuracies [CIFAR10]')
    209    plt.xlabel('epoch')
    210    plt.ylabel('accuracy')
    211    plt.legend()
    212    plt.savefig('img/cifar10_acc.png')
    213    plt.close()

    前編の記事はこちら

    featureImg2020.02.07【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】PyTorchでCIFAR-10をCNNに学習させる前回の『【PyTorch入門】PyTorchで手書き数字(MNIS...

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    関連記事

    featureImg2020.01.23【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させるPyTorchで手書き数字(MNIST)を学習させる前回は、PyTorch(パイトーチ)のインストールなどを行いました...

    ライトコードでは、エンジニアを積極採用中!

    ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。

    採用情報へ

    広告メディア事業部

    広告メディア事業部

    おすすめ記事

    エンジニア大募集中!

    ライトコードでは、エンジニアを積極採用中です。

    特に、WEBエンジニアとモバイルエンジニアは是非ご応募お待ちしております!

    また、フリーランスエンジニア様も大募集中です。

    background