1. HOME
  2. ブログ
  3. IT技術
  4. 【後編】「Keras」と「PyTorch」を徹底比較してみた!~データローダ・転移学習編~

【後編】「Keras」と「PyTorch」を徹底比較してみた!~データローダ・転移学習編~

【後編】Keras と PyTorch を比較したい!

前回は、基本的な「Keras」と「PyTorch」の違いについて比較してみました。

今回は、実際にサンプルをカスタムして自前データを学習させる事を想定して、Keras と PyTorch を比較していきたいと思います!

自前データを学習させるために、「どの部分」を「どのくらいカスタム」しなければいけないかを検討していきます!

前編はこちら

「データローダ」を比較!

「データローダ」は、実際のデータをニューラルネットに供給する役目を果たします。

有名どころのデータセットは、「Keras」も「PyTorch」もビルトインのデータセットを持っています。

どちらも1行書くだけで OK です。

また、ローカルになければ、オリジナルからダウンロードして持ってきてくれるという便利な機能があります。

実際の読み込みは、以下のような感じになります。

Kerasの場合

PyTorchの場合

どちらも同じようなシンタックスになっています。

ただ、中身を詳細に見ていくと、PyTorch の方が対応しているデータセットが多く、データセットの豊富さで PyTorch を選ぶ方も多いかと思います。

カスタムデータローダを比較!

さて、今回の話の目玉はここからです!

まずは、「カスタムデータローダ」について簡単に説明したいと思います。

自前のデータを学習させてみたいという場合、カスタムデータローダは必須の機能となります。

まず、必須機能である、ディレクトリに置かれたデータを読みだし、供給してくれる機能を説明していきます。

これは、「Torchvision」の説明からの図になります。

(Torchvision、Keras、どちらも同じディレクトリを対象としています。)

root の下に、クラスラベルにあたるサブディレクトリを作ります。

そして、作成したサブディレクトリの下に、対象クラスの画像を入れるという形式になります。

ここで、どちらも Label は「One hot ベクター形式」となります。

そのため、このベクターの出現順序は、サブディレクトリ名をソートしたものを「0」から並べたものになります。

Keras で実装

まずは、Keras での実装を見てみたいと思います!

実際のデータ生成は、 ImageDataGenerator が行います。

ImageDataGenerator は、データオーギュメンテーション(データ拡張)のための、様々な機能が実装されているクラスです。

ここで、 flow_from_directory がディレクトリからの読み出しの機能を持っています。

データオーギュメンテーションの細かい実装では、Keras の方が高機能です。

ベースネットで優秀なものを使う場合には、Keras もなかなか使いやすいように思います。

ただ、高機能ではありますが、使用する際に難しい機能も多く、一概に何とも言えないところもあります。

実際の研究や実務では、ネットワークをいじることは初期の段階でしか行いません。

データをいじることが大部分というパターンが多く、いろいろなデータを供給できるようにするのは合理的です。

PyTorch で実装

続いて、PyTorch での実装を見ていきたいと思います!

torchvision を使った実装です。

torch.utils.data での実装もできますが、torchvision の方がはるかに楽です。

torchvision はとても強力なライブラリで、torchvision のコア機能のひとつが「transform」です。

「transform」機能とは?
transform は、データの形式を変形していく機能です。
インポート形式が画像の場合、「リサイズして、画像を読み込んで、画素を正規化して入力する」というような形がニューラルネットでは一般的です。
transform では、この一連の流れを配列の中に列挙することで、一括して処理をしてくれます。

このあたりの記述については、「PyTorch が圧勝」です。

Keras の場合だと、かなり習熟した人がプログラムを書き込んでいかないと、できない内容です。

また、前回の記事でも解説したように、デバッグのやりやすさでも PyTorch が優れています。

いざ、自前のデータをいろいろ加工して学習させようという場合には、PyTorch に軍配が上がります。

「転移学習」を比較!

次に、「転移学習」について、Keras と PyTorch での違いを調べていきましょう!

「転移学習」とは?
大きなデータセットを使って、大きな計算機で十分に学習したモデルを特徴量の抽出器として使い、最後に、自分のデータを学習させて、結果を得るという方法。
画像のクラシフィケーションやセグメンテーションなどでは、デフォルトの手法となっています。

試しに、「ResNet」を元ネタ(ネットワーク)として、自分のクラスを上記のデータローダを利用して学習させてみましょう!

Keras で実装

まずは、Keras から。

Keras の場合は、以下のようなコードとなります。

img_widthimg_height には「224」を入れて下さい。

nb_classes には自分の分類したいクラス数を書きます。

これだけの変更で動作します。

どうでしょうか?簡単ですね!

今や、クラシフィケーションタスクであれば、転移学習を使った方が結果がでます。

自分でネットを書くよりも簡単に、そして、確実に動作させることができます。

PyTorch で実装

続いて、PyTorch!

PyTorch の場合は、以下のようなコードとなります。

PyTorch のチュートリアルにあるもので、実際に動作したので、これを掲載しました。

Keras と同様に nb_classes に、自分のクラス数を入れることが必要です。

コードは「Cuda」対応となっています。

転移学習をさせる場合には、ネットが大きくなるので、GPU を使わないとツライと思います。

PyTorch のチュートリアルには、転移学習でファインチューニングをやるやり方が掲載されています。

実際に使用する場合には、チェックする事をオススメします!

さいごに

実践的に、自前データを回す場合に必要となる「データローダ」と、結果を出すために一番の近道である「転移学習」。

今回は、その2つの機能を、「Keras」と「PyTorch」で比較してみました!

どちらでも基本機能に差はないので、自分にとって使いやすい方を選ぶのがいいと思います。

「機械学習を学んでみたい!」「機械学習初心者だけど、仕事で実務を回したい!」というような方には Keras

「バリバリプログラムを書くぞ!」というような方や、元々、他のエリアで十分に経験を積み重ねているようなプログラマであれば、PyTorch が良い選択だと思います!

ちなみに、先日、PyTorch は「PyTorch 1.5」がリリースされました。

「Facebook」と「Microsoft」という、2大巨頭をスポンサーにしている PyTorch は、強力に開発が進んでいます。

ライブラリも拡充されて、今後、どんどん使いやすいものになっていくものと思われます。

比較的新しいネットワークもどんどん拡充されていきますので、「将来性からいうと PyTorch」といった感じもあります。

初学者の方はもちろんのこと、「機械学習にどっぷりだけど、Keras か PytTorch にはまだ触ってない」という方も、ぜひこの機会に、Keras と PyTorch に触れてみましょう!

こちらの記事もオススメ!

関連記事

書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア