• トップ
  • ブログ一覧
  • 【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
  • 【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

    広告メディア事業部広告メディア事業部
    2020.02.07

    IT技術

    PyTorchでCIFAR-10をCNNに学習させる

    前回の『【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる』に引き続き、PyTorchで機械学習を学んでいきましょう!

    今回は、PyTorchで畳み込みニューラルネットワーク(CNN)を実装していきます。

    ちなみに、公式ドキュメントにも同じような実装が紹介されているようです。

    ですが、本記事では、日本語で分かりやすく詳細に解説していきたいと思っています。

    さらに最後には、ネットワークの内部を可視化してみたり、GPUを使用してみたりと、様々な実験が含まれている記事になっています!

    ですので、ぜひ最後まで読んでみてください!

    インポートされているモジュール

    これから実装するコードは、以下のモジュールがあらかじめインポートされています。

    1import torch
    2import torch.nn.functional as f
    3from torch.utils.data import DataLoader
    4from torchvision import datasets, transforms
    5import matplotlib.pyplot as plt
    6from tqdm import tqdm

    それでは、実際に実装していきましょう!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    CIFAR10の準備

    今回用いるデータセットの「CIFAR10 (サイファー10)」 は、32×32 のカラー画像からなるデータセットで、その名の通り10クラスあります。

    「MNIST」は 28×28 のグレースケール画像なので、「CIFER10」の方が情報量は数倍多く、学習は難しいです。

    畳み込みニューラルネットワークを学んだり、研究しているほとんどの人は、このデータセットをカラー画像の簡易なベンチマークとして使用しています。

    それでは、このデータセットを読み込む関数を作っていきます。

    CIFER10を読み込む関数を作る

    1def load_cifar10(batch=128):
    2    train_loader = DataLoader(
    3        datasets.CIFAR10('./data',
    4                         train=True,
    5                         download=True,
    6                         transform=transforms.Compose([
    7                             transforms.ToTensor(),
    8                             transforms.Normalize(
    9                                [0.5, 0.5, 0.5],  # RGB 平均
    10                                [0.5, 0.5, 0.5]   # RGB 標準偏差
    11                                )
    12                         ])),
    13        batch_size=batch,
    14        shuffle=True
    15    )
    16
    17    test_loader = DataLoader(
    18        datasets.CIFAR10('./data',
    19                         train=False,
    20                         download=True,
    21                         transform=transforms.Compose([
    22                             transforms.ToTensor(),
    23                             transforms.Normalize(
    24                                 [0.5, 0.5, 0.5],  # RGB 平均
    25                                 [0.5, 0.5, 0.5]  # RGB 標準偏差
    26                             )
    27                         ])),
    28        batch_size=batch,
    29        shuffle=True
    30    )
    31
    32    return {'train': train_loader, 'test': test_loader}

    「MNIST」で行った時とほとんど変わりませんが、今回データの正規化として、各カラーチャネルの平均と標準偏差を「0.5」になるようにしています。

    画像データセットでよくある「正規化」ですので、覚えておきましょう!

    動作確認

    それでは、試しに動作確認をしてみましょう!

    1if __name__ == '__main__':
    2    loader = load_cifar10()
    3    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    4               'dog', 'frog', 'horse', 'ship', 'truck')  # CIFAR10のクラス
    5
    6    for images, labels in loader['train']:
    7        print(images.shape)  # torch.Size([128, 3, 32, 32])
    8
    9        # 試しに50枚を 5x10 で見てみる
    10        for i in range(5):
    11            for j in range(10):
    12                image = images[i*10+j] / 2 + 0.5
    13                image = image.numpy()
    14                plt.subplot(5, 10, i*10+j + 1)
    15                plt.imshow(np.transpose(image, (1, 2, 0)))  # matplotlibではチャネルは第3次元
    16                
    17                # 対応するラベル
    18                plt.title(classes[int(labels[i*10+j])])
    19                
    20                # 軸目盛や値はいらないので消す
    21                plt.tick_params(labelbottom=False,
    22                                labelleft=False,
    23                                labelright=False,
    24                                labeltop=False,
    25                                bottom=False,
    26                                left=False,
    27                                right=False,
    28                                top=False)
    29
    30        plt.show()
    31        break

    画像の描画はイテレーションループを使ってますが、もちろん iter() でもOKです!

    PyTorch のローダーを使って取得したデータセットは、イテレータで取得したときに [バッチサイズ, チャネル, (画像のシェイプ)]というテンソル型の画像と、[バッチサイズ] というテンソル型のラベルを返します。

    描画結果

    実際に50枚描画してみると、以下のようになりました。

    CIFAR10の例

    32×32 なので粗い画像ですが、画像とラベルが一致していそうですね。

    これらの画像を、今から畳み込みニューラルネットワーク(CNN)に学習させていきます!

    CNNの構築

    それでは、早速ネットワークを構築していきます。

    今回は以下のような、「LeNet」 と呼ばれる畳み込みニューラルネットワーク(CNN)をベースに構築し、学習させていきます。

    「LeNet」の構成

    「LeNet」が提案されたのは1998年と古いものですが、畳み込みニューラルネットワーク(CNN)という名を有名にさせたネットワークです。

    LeNetの構成

    このネットワークを、ほとんどそのまま実装してみると以下のようになります。

    (実際には活性化関数など、一部元論文と異なります)

    実装

    1class MyCNN(torch.nn.Module):
    2    def __init__(self):
    3        super(MyCNN, self).__init__()
    4        self.conv1 = torch.nn.Conv2d(3,  # チャネル入力
    5                                     6,  # チャンネル出力
    6                                     5,  # カーネルサイズ
    7                                     1,  # ストライド (デフォルトは1)
    8                                     0,  # パディング (デフォルトは0)
    9                                     )
    10        self.conv2 = torch.nn.Conv2d(6, 16, 5)
    11
    12        self.pool = torch.nn.MaxPool2d(2, 2)  # カーネルサイズ, ストライド
    13
    14        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)  # 入力サイズ, 出力サイズ
    15        self.fc2 = torch.nn.Linear(120, 84)
    16        self.fc3 = torch.nn.Linear(84, 10)
    17
    18    def forward(self, x):
    19        x = f.relu(self.conv1(x))
    20        x = self.pool(x)
    21        x = f.relu(self.conv2(x))
    22        x = self.pool(x)
    23        x = x.view(-1, 16 * 5 * 5)  # 1次元データに変えて全結合層へ
    24        x = f.relu(self.fc1(x))
    25        x = f.relu(self.fc2(x))
    26        x = self.fc3(x)
    27
    28        return x

    ネットワーク名は MyCNN としました。

    通常の画像を畳み込む場合、torch.nn.Conv2d を用いますが引数についてはコメントのとおりです。

    構築の仕方は「MNIST」の時とほとんど同じなので分かりやすいかと思います。

    このとき、(入力チャネル)×(出力チャネル)が畳み込みフィルタの数になり、これらはネットワークが構築された段階でランダムに初期化されます。

    この畳み込みフィルタは、『学習の過程でどう変化していくのか』を観察する予定です。

    畳み込みフィルタを可視化する関数をつくる

    では最初に、可視化用の関数をつくっていきましょう!

    先ほどの MyCNN クラスのメンバ関数でOKですので、以下の関数を加筆してください。

    1    def plot_conv1(self, prefix_num=0):
    2        weights1 = self.conv1.weight
    3        weights1 = weights1.reshape(3*6, 5, 5)
    4
    5        for i, weight in enumerate(weights1):
    6            plt.subplot(3, 6, i + 1)
    7            plt.imshow(weight.data.numpy(), cmap='winter')
    8            plt.tick_params(labelbottom=False,
    9                            labelleft=False,
    10                            labelright=False,
    11                            labeltop=False,
    12                            bottom=False,
    13                            left=False,
    14                            right=False,
    15                            top=False)
    16
    17        plt.savefig('img/{}_conv1.png'.format(prefix_num))
    18        plt.close()
    19
    20    def plot_conv2(self, prefix_num=0):
    21        weights2 = self.conv2.weight
    22        weights2 = weights2.reshape(6*16, 5, 5)
    23
    24        for i, weight in enumerate(weights2):
    25            plt.subplot(6, 16, i + 1)
    26            plt.imshow(weight.data.numpy(), cmap='winter')
    27            plt.tick_params(labelbottom=False,
    28                            labelleft=False,
    29                            labelright=False,
    30                            labeltop=False,
    31                            bottom=False,
    32                            left=False,
    33                            right=False,
    34                            top=False)
    35
    36        plt.savefig('img/{}_conv2.png'.format(prefix_num))
    37        plt.close()

    各レイヤーの重み情報は weight というメンバ変数が保持しています。

    これは単純な重み情報だけでなく、勾配情報やデバイス情報(CPU or GPU)などを保持しているので、純粋な重みを取り出す場合 weight.data とします。

    あとは、先ほどの「CIFAR10」の可視化と、ほとんど一緒ですね。

    ちなみに、学習前の重みはこんな感じです。

    学習前のconv1
    学習前のconv2

    ただ、ランダムなので、まだ何がなんだかよくわかりませんね(笑)

    学習部を作る

    それでは、学習部を実装していきます。

    これも「MNIST」の時とほとんど同様ですが、今回は損失関数として「クロスエントロピー」を用います。

    メイン処理部分の、ネットワーク構築から訓練までは以下のようにしました。

    1if __name__ == '__main__':
    2    epoch = 50
    3
    4    loader = load_cifar10()
    5    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    6               'dog', 'frog', 'horse', 'ship', 'truck')
    7
    8    net: MyCNN = MyCNN()
    9    criterion = torch.nn.CrossEntropyLoss()  # ロスの計算
    10    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    11
    12    # 学習前のフィルタの可視化
    13    net.plot_conv1()
    14    net.plot_conv2()
    15
    16    history = {
    17        'train_loss': [],
    18        'train_acc': [],
    19        'test_acc': []
    20    }
    21
    22    for e in range(epoch):
    23        net.train()
    24        loss = None
    25        for i, (images, labels) in enumerate(loader['train']):
    26            optimizer.zero_grad()
    27            output = net(images)
    28            loss = criterion(output, labels)
    29            loss.backward()
    30            optimizer.step()
    31
    32            if i % 10 == 0:
    33                print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1,
    34                                                                                         (i + 1) * 128,
    35                                                                                         loss.item())
    36                      )
    37
    38        # 学習過程でのフィルタの可視化
    39        net.plot_conv1(e+1)
    40        net.plot_conv2(e+1)
    41
    42        history['train_loss'].append(loss.item())

    ひとまず、学習は50エポック分行いたいと思います。

    最終的な結果として、訓練ロスと訓練精度、そしてテスト精度の変化が得られるようにします。

    それらは history という名の、辞書型変数に格納する形になっています。

    テスト部分

    次にテスト部分ですが、これも「MNIST」の時と大差はありません。

    今回は、訓練精度とテスト精度を見たいので、2つのループがあります。

    実装

    テスト部分から最後の結果を描画するまでは、以下のように実装しました。

    1        net.eval()
    2        correct = 0
    3        with torch.no_grad():
    4            for i, (images, labels) in enumerate(tqdm(loader['train'])):
    5                outputs = net(images)
    6                _, predicted = torch.max(outputs.data, 1)
    7                correct += (predicted == labels).sum().item()
    8
    9        acc = float(correct / 50000)
    10        history['train_acc'].append(acc)
    11
    12        correct = 0
    13        with torch.no_grad():
    14            for i, (images, labels) in enumerate(tqdm(loader['test'])):
    15                outputs = net(images)
    16                _, predicted = torch.max(outputs.data, 1)
    17                correct += (predicted == labels).sum().item()
    18
    19        acc = float(correct / 10000)
    20        history['test_acc'].append(acc)
    21
    22    # 結果をプロット
    23    plt.plot(range(1, epoch+1), history['train_loss'])
    24    plt.title('Training Loss [CIFAR10]')
    25    plt.xlabel('epoch')
    26    plt.ylabel('loss')
    27    plt.savefig('img/cifar10_loss.png')
    28    plt.close()
    29
    30    plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc')
    31    plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc')
    32    plt.title('Accuracies [CIFAR10]')
    33    plt.xlabel('epoch')
    34    plt.ylabel('accuracy')
    35    plt.legend()
    36    plt.savefig('img/cifar10_acc.png')
    37    plt.close()

    それでは、早速学習させてみます!

    結果が気になるところですが…今回はここまで!

    次回は、実際に学習させてみたり、GPUを使ってみたりしたいと思いますのでお楽しみに!

    後編はこちら

    featureImg2020.02.13【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】【後編】PyTorchでCIFAR-10をCNNに学習させる【前編】の続きとなります。引き続き、PyTorch(パイト...

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    関連記事

    featureImg2020.01.23【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させるPyTorchで手書き数字(MNIST)を学習させる前回は、PyTorch(パイトーチ)のインストールなどを行いました...

    広告メディア事業部

    広告メディア事業部

    おすすめ記事