• トップ
  • ブログ一覧
  • 【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】
  • 【前編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】

    広告メディア事業部広告メディア事業部
    2020.02.07

    IT技術

    PyTorchでCIFAR-10をCNNに学習させる

    前回の『【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させる』に引き続き、PyTorchで機械学習を学んでいきましょう!

    今回は、PyTorchで畳み込みニューラルネットワーク(CNN)を実装していきます。

    ちなみに、公式ドキュメントにも同じような実装が紹介されているようです。

    ですが、本記事では、日本語で分かりやすく詳細に解説していきたいと思っています。

    さらに最後には、ネットワークの内部を可視化してみたり、GPUを使用してみたりと、様々な実験が含まれている記事になっています!

    ですので、ぜひ最後まで読んでみてください!

    インポートされているモジュール

    これから実装するコードは、以下のモジュールがあらかじめインポートされています。

    1import torch
    2import torch.nn.functional as f
    3from torch.utils.data import DataLoader
    4from torchvision import datasets, transforms
    5import matplotlib.pyplot as plt
    6from tqdm import tqdm

    それでは、実際に実装していきましょう!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    CIFAR10の準備

    今回用いるデータセットの「CIFAR10 (サイファー10)」 は、32×32 のカラー画像からなるデータセットで、その名の通り10クラスあります。

    「MNIST」は 28×28 のグレースケール画像なので、「CIFER10」の方が情報量は数倍多く、学習は難しいです。

    畳み込みニューラルネットワークを学んだり、研究しているほとんどの人は、このデータセットをカラー画像の簡易なベンチマークとして使用しています。

    それでは、このデータセットを読み込む関数を作っていきます。

    CIFER10を読み込む関数を作る

    1def load_cifar10(batch=128):
    2    train_loader = DataLoader(
    3        datasets.CIFAR10('./data',
    4                         train=True,
    5                         download=True,
    6                         transform=transforms.Compose([
    7                             transforms.ToTensor(),
    8                             transforms.Normalize(
    9                                [0.5, 0.5, 0.5],  # RGB 平均
    10                                [0.5, 0.5, 0.5]   # RGB 標準偏差
    11                                )
    12                         ])),
    13        batch_size=batch,
    14        shuffle=True
    15    )
    16
    17    test_loader = DataLoader(
    18        datasets.CIFAR10('./data',
    19                         train=False,
    20                         download=True,
    21                         transform=transforms.Compose([
    22                             transforms.ToTensor(),
    23                             transforms.Normalize(
    24                                 [0.5, 0.5, 0.5],  # RGB 平均
    25                                 [0.5, 0.5, 0.5]  # RGB 標準偏差
    26                             )
    27                         ])),
    28        batch_size=batch,
    29        shuffle=True
    30    )
    31
    32    return {'train': train_loader, 'test': test_loader}

    「MNIST」で行った時とほとんど変わりませんが、今回データの正規化として、各カラーチャネルの平均と標準偏差を「0.5」になるようにしています。

    画像データセットでよくある「正規化」ですので、覚えておきましょう!

    動作確認

    それでは、試しに動作確認をしてみましょう!

    1if __name__ == '__main__':
    2    loader = load_cifar10()
    3    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    4               'dog', 'frog', 'horse', 'ship', 'truck')  # CIFAR10のクラス
    5
    6    for images, labels in loader['train']:
    7        print(images.shape)  # torch.Size([128, 3, 32, 32])
    8
    9        # 試しに50枚を 5x10 で見てみる
    10        for i in range(5):
    11            for j in range(10):
    12                image = images[i*10+j] / 2 + 0.5
    13                image = image.numpy()
    14                plt.subplot(5, 10, i*10+j + 1)
    15                plt.imshow(np.transpose(image, (1, 2, 0)))  # matplotlibではチャネルは第3次元
    16                
    17                # 対応するラベル
    18                plt.title(classes[int(labels[i*10+j])])
    19                
    20                # 軸目盛や値はいらないので消す
    21                plt.tick_params(labelbottom=False,
    22                                labelleft=False,
    23                                labelright=False,
    24                                labeltop=False,
    25                                bottom=False,
    26                                left=False,
    27                                right=False,
    28                                top=False)
    29
    30        plt.show()
    31        break

    画像の描画はイテレーションループを使ってますが、もちろん iter() でもOKです!

    PyTorch のローダーを使って取得したデータセットは、イテレータで取得したときに [バッチサイズ, チャネル, (画像のシェイプ)]というテンソル型の画像と、[バッチサイズ] というテンソル型のラベルを返します。

    描画結果

    実際に50枚描画してみると、以下のようになりました。

    CIFAR10の例

    32×32 なので粗い画像ですが、画像とラベルが一致していそうですね。

    これらの画像を、今から畳み込みニューラルネットワーク(CNN)に学習させていきます!

    CNNの構築

    それでは、早速ネットワークを構築していきます。

    今回は以下のような、「LeNet」 と呼ばれる畳み込みニューラルネットワーク(CNN)をベースに構築し、学習させていきます。

    「LeNet」の構成

    「LeNet」が提案されたのは1998年と古いものですが、畳み込みニューラルネットワーク(CNN)という名を有名にさせたネットワークです。

    LeNetの構成

    このネットワークを、ほとんどそのまま実装してみると以下のようになります。

    (実際には活性化関数など、一部元論文と異なります)

    実装

    1class MyCNN(torch.nn.Module):
    2    def __init__(self):
    3        super(MyCNN, self).__init__()
    4        self.conv1 = torch.nn.Conv2d(3,  # チャネル入力
    5                                     6,  # チャンネル出力
    6                                     5,  # カーネルサイズ
    7                                     1,  # ストライド (デフォルトは1)
    8                                     0,  # パディング (デフォルトは0)
    9                                     )
    10        self.conv2 = torch.nn.Conv2d(6, 16, 5)
    11
    12        self.pool = torch.nn.MaxPool2d(2, 2)  # カーネルサイズ, ストライド
    13
    14        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)  # 入力サイズ, 出力サイズ
    15        self.fc2 = torch.nn.Linear(120, 84)
    16        self.fc3 = torch.nn.Linear(84, 10)
    17
    18    def forward(self, x):
    19        x = f.relu(self.conv1(x))
    20        x = self.pool(x)
    21        x = f.relu(self.conv2(x))
    22        x = self.pool(x)
    23        x = x.view(-1, 16 * 5 * 5)  # 1次元データに変えて全結合層へ
    24        x = f.relu(self.fc1(x))
    25        x = f.relu(self.fc2(x))
    26        x = self.fc3(x)
    27
    28        return x

    ネットワーク名は MyCNN としました。

    通常の画像を畳み込む場合、torch.nn.Conv2d を用いますが引数についてはコメントのとおりです。

    構築の仕方は「MNIST」の時とほとんど同じなので分かりやすいかと思います。

    このとき、(入力チャネル)×(出力チャネル)が畳み込みフィルタの数になり、これらはネットワークが構築された段階でランダムに初期化されます。

    この畳み込みフィルタは、『学習の過程でどう変化していくのか』を観察する予定です。

    畳み込みフィルタを可視化する関数をつくる

    では最初に、可視化用の関数をつくっていきましょう!

    先ほどの MyCNN クラスのメンバ関数でOKですので、以下の関数を加筆してください。

    1    def plot_conv1(self, prefix_num=0):
    2        weights1 = self.conv1.weight
    3        weights1 = weights1.reshape(3*6, 5, 5)
    4
    5        for i, weight in enumerate(weights1):
    6            plt.subplot(3, 6, i + 1)
    7            plt.imshow(weight.data.numpy(), cmap='winter')
    8            plt.tick_params(labelbottom=False,
    9                            labelleft=False,
    10                            labelright=False,
    11                            labeltop=False,
    12                            bottom=False,
    13                            left=False,
    14                            right=False,
    15                            top=False)
    16
    17        plt.savefig('img/{}_conv1.png'.format(prefix_num))
    18        plt.close()
    19
    20    def plot_conv2(self, prefix_num=0):
    21        weights2 = self.conv2.weight
    22        weights2 = weights2.reshape(6*16, 5, 5)
    23
    24        for i, weight in enumerate(weights2):
    25            plt.subplot(6, 16, i + 1)
    26            plt.imshow(weight.data.numpy(), cmap='winter')
    27            plt.tick_params(labelbottom=False,
    28                            labelleft=False,
    29                            labelright=False,
    30                            labeltop=False,
    31                            bottom=False,
    32                            left=False,
    33                            right=False,
    34                            top=False)
    35
    36        plt.savefig('img/{}_conv2.png'.format(prefix_num))
    37        plt.close()

    各レイヤーの重み情報は weight というメンバ変数が保持しています。

    これは単純な重み情報だけでなく、勾配情報やデバイス情報(CPU or GPU)などを保持しているので、純粋な重みを取り出す場合 weight.data とします。

    あとは、先ほどの「CIFAR10」の可視化と、ほとんど一緒ですね。

    ちなみに、学習前の重みはこんな感じです。

    学習前のconv1
    学習前のconv2

    ただ、ランダムなので、まだ何がなんだかよくわかりませんね(笑)

    学習部を作る

    それでは、学習部を実装していきます。

    これも「MNIST」の時とほとんど同様ですが、今回は損失関数として「クロスエントロピー」を用います。

    メイン処理部分の、ネットワーク構築から訓練までは以下のようにしました。

    1if __name__ == '__main__':
    2    epoch = 50
    3
    4    loader = load_cifar10()
    5    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    6               'dog', 'frog', 'horse', 'ship', 'truck')
    7
    8    net: MyCNN = MyCNN()
    9    criterion = torch.nn.CrossEntropyLoss()  # ロスの計算
    10    optimizer = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    11
    12    # 学習前のフィルタの可視化
    13    net.plot_conv1()
    14    net.plot_conv2()
    15
    16    history = {
    17        'train_loss': [],
    18        'train_acc': [],
    19        'test_acc': []
    20    }
    21
    22    for e in range(epoch):
    23        net.train()
    24        loss = None
    25        for i, (images, labels) in enumerate(loader['train']):
    26            optimizer.zero_grad()
    27            output = net(images)
    28            loss = criterion(output, labels)
    29            loss.backward()
    30            optimizer.step()
    31
    32            if i % 10 == 0:
    33                print('Training log: {} epoch ({} / 50000 train. data). Loss: {}'.format(e + 1,
    34                                                                                         (i + 1) * 128,
    35                                                                                         loss.item())
    36                      )
    37
    38        # 学習過程でのフィルタの可視化
    39        net.plot_conv1(e+1)
    40        net.plot_conv2(e+1)
    41
    42        history['train_loss'].append(loss.item())

    ひとまず、学習は50エポック分行いたいと思います。

    最終的な結果として、訓練ロスと訓練精度、そしてテスト精度の変化が得られるようにします。

    それらは history という名の、辞書型変数に格納する形になっています。

    テスト部分

    次にテスト部分ですが、これも「MNIST」の時と大差はありません。

    今回は、訓練精度とテスト精度を見たいので、2つのループがあります。

    実装

    テスト部分から最後の結果を描画するまでは、以下のように実装しました。

    1        net.eval()
    2        correct = 0
    3        with torch.no_grad():
    4            for i, (images, labels) in enumerate(tqdm(loader['train'])):
    5                outputs = net(images)
    6                _, predicted = torch.max(outputs.data, 1)
    7                correct += (predicted == labels).sum().item()
    8
    9        acc = float(correct / 50000)
    10        history['train_acc'].append(acc)
    11
    12        correct = 0
    13        with torch.no_grad():
    14            for i, (images, labels) in enumerate(tqdm(loader['test'])):
    15                outputs = net(images)
    16                _, predicted = torch.max(outputs.data, 1)
    17                correct += (predicted == labels).sum().item()
    18
    19        acc = float(correct / 10000)
    20        history['test_acc'].append(acc)
    21
    22    # 結果をプロット
    23    plt.plot(range(1, epoch+1), history['train_loss'])
    24    plt.title('Training Loss [CIFAR10]')
    25    plt.xlabel('epoch')
    26    plt.ylabel('loss')
    27    plt.savefig('img/cifar10_loss.png')
    28    plt.close()
    29
    30    plt.plot(range(1, epoch + 1), history['train_acc'], label='train_acc')
    31    plt.plot(range(1, epoch + 1), history['test_acc'], label='test_acc')
    32    plt.title('Accuracies [CIFAR10]')
    33    plt.xlabel('epoch')
    34    plt.ylabel('accuracy')
    35    plt.legend()
    36    plt.savefig('img/cifar10_acc.png')
    37    plt.close()

    それでは、早速学習させてみます!

    結果が気になるところですが…今回はここまで!

    次回は、実際に学習させてみたり、GPUを使ってみたりしたいと思いますのでお楽しみに!

    後編はこちら

    featureImg2020.02.13【後編】PyTorchでCIFAR-10をCNNに学習させる【PyTorch基礎】【後編】PyTorchでCIFAR-10をCNNに学習させる【前編】の続きとなります。引き続き、PyTorch(パイト...

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    関連記事

    featureImg2020.01.23【PyTorch入門】PyTorchで手書き数字(MNIST)を学習させるPyTorchで手書き数字(MNIST)を学習させる前回は、PyTorch(パイトーチ)のインストールなどを行いました...

    ライトコードでは、エンジニアを積極採用中!

    ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。

    採用情報へ

    広告メディア事業部
    広告メディア事業部
    Show more...

    おすすめ記事

    エンジニア大募集中!

    ライトコードでは、エンジニアを積極採用中です。

    特に、WEBエンジニアとモバイルエンジニアは是非ご応募お待ちしております!

    また、フリーランスエンジニア様も大募集中です。

    background