1. HOME
  2. ブログ
  3. IT技術
  4. 【前編】「Keras」と「PyTorch」を徹底比較してみた!~MNIST編~

【前編】「Keras」と「PyTorch」を徹底比較してみた!~MNIST編~

【前編】Keras と PyTorch を比較したい!

現在、Keras(ケラス)と PyTorch(パイトーチ)が、機械学習フレームワークの中で人気 No 1、2を争っています。

今回は、この2つを「様々な角度」から徹底比較してみようと思います!

SageMaker(セージメイカー)

また、本記事では、「SageMaker」を使っていきたいと思います。

SageMaker では、「Jupyter Notebook」でビルトインカーネルを切り替えるだけで、どちらのフレームワークも簡単に動かすことができます。

インストールする手間もなく、おすすめです!

MNIST で比較!(Kerasの場合)

まずは、Keras について見ていきたいと思います。

「Sequential モード」の場合を想定して、解説していきます。

モデル定義パート

Keras の MNIST のモデル定義の部分は、以下のような感じになります。

Sequential を使ったモデル定義の場合だと、 model.add() で、使いたいレイヤを必要なパラメータと共に付け加えていくだけで簡単に定義できてしまいます。

また、一番最初のレイヤだけは input_shape=  の部分で画像サイズ(あるいはテンソルのサイズ)を入力していますが、そのあとのレイヤでは自動で計算してくれます。

モデル定義の見やすさは、Keras が有利でしょう。

トレーニングパート

続いて、トレーニング部分を見ていきます!

model.compile() でオプティマイザと損失関数をセットしています。

あとは model.fit()  をコールすれば、Epoch 回数分の学習が走ります。

verbose=1 としているので、中間結果としての損失と Accuracy が表示されます。

損失関数」と「オプティマイザ」は通常セットで変更するので、compile と fit パートを分けているのは、とても合理的な実装だと思います。

とてもシンプルかつ簡単で、そういった面では Keras が断然優位です。

MNISTで比較!(PyTorchの場合)

続いて、PyTorch で同じ部分を見てみましょう!

モデル定義パート

モデル定義は、次のようにクラスを使って行います。

PyTorch では、クラスの初期化と Forward 関数を定義する必要があります。

初期化では、使用する層についての定義を行い、Forward 関数でこれをつなげていきます。

各層への「入力サイズ」と「出力サイズ」をきちんと書いていかなければならない事が、Keras と大きく違います。

畳み込み計算などは、出力サイズを計算するのが少し面倒です。

このあたりについては、Keras を使った方が余計なことを考えずにネットを組むことができて楽です。

「全てが表示されていて、隠されたものがない」というあたりは、とても見通しが良い印象です。

Keras だと「あれ?これだけでいいの?」と逆に考え込んでしまうようなら、PyTorch の方がいいかもしれません。

トレーニングパート

続いて、トレーニング部分を見ていきます!

Keras とは大きく異なり、Epoch ごとに入力テンソルを Loader より取り込み、オプティマイザを初期化した後、計算を実行します。

計算が終わったら、

  1. 損失値を算出
  2. 微分
  3. オプティマイザをかける

という一連の流れをすべて記述する必用があります。

このあたりは、ほとんどが定型的なものなので、覚えてしまえば、それほど難しくはありません。

ただ、それぞれ微妙に調整しなければいけないので、調整のやり方がよくわからない初心者の方には少しハードルが高いかもしれません。

このやり方は、「Define by Run」と呼ばれるやり方ですが、「なんでこれが良いの?」と思う人もいるかもしれません。

ですが、実際にデバッグをするときには、こちらの方が圧倒的にやりやすいのです。

データの流れを追ってみる!

せっかくデータの流れが追えるので、少しだけデータの流れを追ってみましょう。

enumerate で「一回分のデータxバッチサイズ」の配列が datatarget に渡されます。

(ここでの  target とはラベルのことです。)

次に、「CPU」か「GPU」に渡されたデータを送ります。

データを渡したら、続いて、オプティマイザを初期化します。

一回のネットワークでの epoch は  output=model(data) で行われます。

epoch が走ったら(もしくは実際のネットワークでの加乗算が走ったら)、損失を計算し、微分をして、ネットワーク内部の Weight を更新していきます。

この流れは、どのニューラルネットでも同じ動きで、その動きが見える、データの流れが見えるというところは、PyTorch の素敵な部分の1つです。

デバッグがやりやすい

デバッグがやりやすいのも、特徴の1つです。

たとえば Jupyter であれば、問題のあったところに「pdb」を埋め込めば、その場でテンソルの内容を確認することができます!

これがそのデバッグ文ですね。

これをいれると、その場で pdb の入力ができるようになります。

試した後は、「c」を入力すると元に戻るので、コマンドを知らないうちは「c」を入力してください。

実際に output = model(data) のあとに、この行を挿入してみてください。

入力側のネットワークのTensorと、一回の演算結果での出力側の Tensor について簡単に見ることができます。

GPUへの切り替えが楽!

また、次の行もポイントで、実際に「GPU」を使用するテンソルを決めて、GPU に転送することも簡単に書けます。

Keras で GPU を使う場合は、バックエンドをインストールしなおすことが必要となり、それに比べると PyTorch は非常に楽です。

Keras の場合でも、SageMaker だとカーネルを切り替えるだけで済むので簡単ですが、そうでないない場合は断然、PyTorch が楽です。

このあたりを難しく感じるような初心者のころや、中身ははあまり考えたくない、もしくは必要ないという場合は、Keras をオススメします。

実際に自分でデータを持っていて、「アダプテーションを行うから、実際のデータをデバッグで見ないといけないんだ!」というような人には、PyTorch をオススメします。

「実装状況」で比較

世の中の実装状況を見ていると、感触では、PyTorch の方が実装が増えているように感じます。

が、実際はどうでしょうか?

実際に、本家のサンプルプログラム数を比較してみます。

すると、Keras だと主要なものが「5つ」。

それに対して、PyTorch の方は「12」もあります。

よって、PyTorch の方が数は多いということが言えます。

「ドキュメント」で比較

PyTorch は、昔からとても良いドキュメントが公式サイトに展開されています。

今回、2020年4月に1.5がリリースされてドキュメントが一新されましたが、さらに良くなっています。

【PyTorch ドキュメント】
https://pytorch.org/docs/stable/index.html

特に、「PyTorch Hub」と言われるサイトには、最新のリサーチ(研究)レベルのネットが多く掲載されています。

【PyTorch Hub】
https://pytorch.org/hub/

一方、Kerasでは…

Keras の方は、ドキュメントは PyTorch に比べるとやや薄いという感じがします。

ただ、一貫した思想に基づいて作られているので、こちらの方が好きな人もいるのではないかと感じます。

【Keras ドキュメント】
https://keras.io/about/

「記述性」で比較

記述性では、PyTorch と比較されがちな Keras。

ただ、Keras には、「ファンクションモード」という上級向けの機能があります。

これを使うと、ほぼ PyTorch と同じような記述性を持っています。

また、Callback を使いこなすと、PyTorch の学習ループでやっているようなことも書くことが出来ます。

kerasは、決して機能的に劣っているということはありません。

「学習のしやすさ」で比較

Keras は、初心者がとても学習しやすいものになっています。

これは Keras が、レイヤ(層)中心の考え方で作られているためです。

これは、層と層を繋げていくことで、「ニューラルネットが成り立っている」ということをよく理解できるからです。

また、書籍類についても、Keras の方が、ニューラルネットの解説をしてある本で良書のものが多いです。

PyTorch

PyTorch は、どちらかというと実践向けが多いというイメージです。

書籍類についても、実践的な上級者向けの内容が多いようです。

後編につづく!

今回は、人気の機械学習フレームワーク Keras と PyTorch を簡単に比較してみました。

  1. 初心者や情報系の大学院生などには、「Keras」
  2. 実際に実務でネットを組むような場合には、「PyTorch」

という感じが良いのかなと思いました。

機械学習に興味がある方は、ぜひ一度、Keras と PyTorch を試してみてください!

後編に続く

後編では、もうすこし深く、見ていきたいと思います!

こちらの記事もオススメ!

書いた人はこんな人

広告メディア事業部
広告メディア事業部
「好きを仕事にするエンジニア集団」の(株)ライトコードです!

ライトコードは、福岡、東京、大阪の3拠点で事業展開するIT企業です。
現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。
いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。

システム開発依頼・お見積もり大歓迎!

また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」「WEBディレクター」を積極採用中です!
インターンや新卒採用も行っております。

以下よりご応募をお待ちしております!
https://rightcode.co.jp/recruit

関連記事

採用情報

\ あの有名サービスに参画!? /

バックエンドエンジニア

\ クリエイティブの最前線 /

フロントエンドエンジニア

\ 世界を変える…! /

Androidエンジニア

\ みんなが使うアプリを創る /

iOSエンジニア