• トップ
  • ブログ一覧
  • 【後編】「Keras」と「PyTorch」を徹底比較してみた!~データローダ・転移学習編~
  • 【後編】「Keras」と「PyTorch」を徹底比較してみた!~データローダ・転移学習編~

    広告メディア事業部広告メディア事業部
    2020.08.04

    IT技術

    【後編】Keras と PyTorch を比較したい!

    前回は、基本的な「Keras」と「PyTorch」の違いについて比較してみました。

    今回は、実際にサンプルをカスタムして自前データを学習させる事を想定して、Keras と PyTorch を比較していきたいと思います!

    自前データを学習させるために、「どの部分」を「どのくらいカスタム」しなければいけないかを検討していきます!

    前編はこちら

    featureImg2020.08.03【前編】「Keras」と「PyTorch」を徹底比較してみた!~MNIST編~【前編】Keras と PyTorch を比較したい!現在、Keras(ケラス)と PyTorch(パイトーチ)が、機...

    「データローダ」を比較!

    「データローダ」は、実際のデータをニューラルネットに供給する役目を果たします。

    有名どころのデータセットは、「Keras」も「PyTorch」もビルトインのデータセットを持っています。

    どちらも1行書くだけで OK です。

    また、ローカルになければ、オリジナルからダウンロードして持ってきてくれるという便利な機能があります。

    実際の読み込みは、以下のような感じになります。

    Kerasの場合

    1from keras.datasets import mnist
    2
    3dataset = mnist.load_data()

    PyTorchの場合

    1from tochvision import dataset
    2
    3dataset = datasets.MNIST()

    どちらも同じようなシンタックスになっています。

    ただ、中身を詳細に見ていくと、PyTorch の方が対応しているデータセットが多く、データセットの豊富さで PyTorch を選ぶ方も多いかと思います。

    カスタムデータローダを比較!

    さて、今回の話の目玉はここからです!

    まずは、「カスタムデータローダ」について簡単に説明したいと思います。

    自前のデータを学習させてみたいという場合、カスタムデータローダは必須の機能となります。

    まず、必須機能である、ディレクトリに置かれたデータを読みだし、供給してくれる機能を説明していきます。

    これは、「Torchvision」の説明からの図になります。

    (Torchvision、Keras、どちらも同じディレクトリを対象としています。)

    root の下に、クラスラベルにあたるサブディレクトリを作ります。

    そして、作成したサブディレクトリの下に、対象クラスの画像を入れるという形式になります。

    ここで、どちらも Label は「One hot ベクター形式」となります。

    そのため、このベクターの出現順序は、サブディレクトリ名をソートしたものを「0」から並べたものになります。

    Keras で実装

    まずは、Keras での実装を見てみたいと思います!

    1train_data_dir = './data_root'
    2train_datagen = ImageDataGenerator(
    3    rescale=1. / 255,
    4    shear_range=0.2,
    5    zoom_range=0.2,
    6    horizontal_flip=True)
    7train_generator = train_datagen.flow_from_directory(
    8    train_data_dir,
    9    target_size=(32, 32),
    10    batch_size=8,
    11    class_mode='categorical')

    実際のデータ生成は、ImageDataGenerator が行います。

    ImageDataGenerator は、データオーギュメンテーション(データ拡張)のための、様々な機能が実装されているクラスです。

    ここで、flow_from_directory がディレクトリからの読み出しの機能を持っています。

    データオーギュメンテーションの細かい実装では、Keras の方が高機能です。

    ベースネットで優秀なものを使う場合には、Keras もなかなか使いやすいように思います。

    ただ、高機能ではありますが、使用する際に難しい機能も多く、一概に何とも言えないところもあります。

    実際の研究や実務では、ネットワークをいじることは初期の段階でしか行いません。

    データをいじることが大部分というパターンが多く、いろいろなデータを供給できるようにするのは合理的です。

    PyTorch で実装

    続いて、PyTorch での実装を見ていきたいと思います!

    1transform = transforms.Compose([transforms.Resize([224,224]),
    2                                transforms.RandomHorizontalFlip(),
    3                                transforms.ToTensor()],
    4                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
    5
    6myDataset = datasets.ImageFolder('./data_root',transform =transform)
    7data_loader = torch.utils.data.DataLoader(myDataset,batch_size=16,shuffle=True)

    torchvision を使った実装です。

    torch.utils.data での実装もできますが、torchvision の方がはるかに楽です。

    torchvision はとても強力なライブラリで、torchvision のコア機能のひとつが「transform」です。

    「transform」機能とは?
    transform は、データの形式を変形していく機能です。
    インポート形式が画像の場合、「リサイズして、画像を読み込んで、画素を正規化して入力する」というような形がニューラルネットでは一般的です。
    transform では、この一連の流れを配列の中に列挙することで、一括して処理をしてくれます。

    このあたりの記述については、「PyTorch が圧勝」です。

    Keras の場合だと、かなり習熟した人がプログラムを書き込んでいかないと、できない内容です。

    また、前回の記事でも解説したように、デバッグのやりやすさでも PyTorch が優れています。

    いざ、自前のデータをいろいろ加工して学習させようという場合には、PyTorch に軍配が上がります。

    「転移学習」を比較!

    次に、「転移学習」について、Keras と PyTorch での違いを調べていきましょう!

    「転移学習」とは?
    大きなデータセットを使って、大きな計算機で十分に学習したモデルを特徴量の抽出器として使い、最後に、自分のデータを学習させて、結果を得るという方法。
    画像のクラシフィケーションやセグメンテーションなどでは、デフォルトの手法となっています。

    試しに、「ResNet」を元ネタ(ネットワーク)として、自分のクラスを上記のデータローダを利用して学習させてみましょう!

    Keras で実装

    まずは、Keras から。

    Keras の場合は、以下のようなコードとなります。

    1from keras.applications.resnet50 import ResNet50
    2from keras.models import Sequential, Model
    3from keras.layers import Input, Flatten, Dense
    4
    5input_tensor = Input(shape=(img_width, img_height, 3))
    6ResNet50 = ResNet50(include_top=False, weights='imagenet',input_tensor=input_tensor)
    7
    8top_model = Sequential()
    9top_model.add(Flatten(input_shape=ResNet50.output_shape[1:]))
    10top_model.add(Dense(nb_classes, activation='softmax'))
    11
    12model = Model(input=ResNet50.input, output=top_model(ResNet50.output))

    img_widthimg_height には「224」を入れて下さい。

    nb_classes には自分の分類したいクラス数を書きます。

    これだけの変更で動作します。

    どうでしょうか?簡単ですね!

    今や、クラシフィケーションタスクであれば、転移学習を使った方が結果がでます。

    自分でネットを書くよりも簡単に、そして、確実に動作させることができます。

    PyTorch で実装

    続いて、PyTorch!

    PyTorch の場合は、以下のようなコードとなります。

    1import torch.nn as nn
    2import torch.optim as optim
    3from torch.optim import lr_scheduler
    4
    5no_cuda = False
    6use_cuda = not no_cuda and torch.cuda.is_available()
    7device = torch.device("cuda" if use_cuda else "cpu")
    8
    9model_ft = models.resnet50(pretrained=True)
    10num_ftrs = model_ft.fc.in_features
    11model_ft.fc = nn.Linear(num_ftrs, nb_classes)
    12model_ft = model_ft.to(device)
    13
    14criterion = nn.CrossEntropyLoss()
    15optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

    PyTorch のチュートリアルにあるもので、実際に動作したので、これを掲載しました。

    Keras と同様に nb_classes に、自分のクラス数を入れることが必要です。

    コードは「Cuda」対応となっています。

    転移学習をさせる場合には、ネットが大きくなるので、GPU を使わないとツライと思います。

    PyTorch のチュートリアルには、転移学習でファインチューニングをやるやり方が掲載されています。

    実際に使用する場合には、チェックする事をオススメします!

    さいごに

    実践的に、自前データを回す場合に必要となる「データローダ」と、結果を出すために一番の近道である「転移学習」。

    今回は、その2つの機能を、「Keras」と「PyTorch」で比較してみました!

    どちらでも基本機能に差はないので、自分にとって使いやすい方を選ぶのがいいと思います。

    「機械学習を学んでみたい!」「機械学習初心者だけど、仕事で実務を回したい!」というような方には Keras

    「バリバリプログラムを書くぞ!」というような方や、元々、他のエリアで十分に経験を積み重ねているようなプログラマであれば、PyTorch が良い選択だと思います!

    ちなみに、先日、PyTorch は「PyTorch 1.5」がリリースされました。

    「Facebook」と「Microsoft」という、2大巨頭をスポンサーにしている PyTorch は、強力に開発が進んでいます。

    ライブラリも拡充されて、今後、どんどん使いやすいものになっていくものと思われます。

    比較的新しいネットワークもどんどん拡充されていきますので、「将来性からいうと PyTorch」といった感じもあります。

    初学者の方はもちろんのこと、「機械学習にどっぷりだけど、Keras か PytTorch にはまだ触ってない」という方も、ぜひこの機会に、Keras と PyTorch に触れてみましょう!

    こちらの記事もオススメ!

    featureImg2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...

    featureImg2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...

    関連記事

    featureImg2020.08.03【前編】「Keras」と「PyTorch」を徹底比較してみた!~MNIST編~【前編】Keras と PyTorch を比較したい!現在、Keras(ケラス)と PyTorch(パイトーチ)が、機...

    featureImg2020.08.04【後編】「Keras」と「PyTorch」を徹底比較してみた!~データローダ・転移学習編~【後編】Keras と PyTorch を比較したい!前回は、基本的な「Keras」と「PyTorch」の違いについて...

    広告メディア事業部

    広告メディア事業部

    おすすめ記事