TensorFlowのモデルファイル.pbを.tfliteに変換するときの注意点
IT技術
TensorFlowとは
TensorFlow(テンサーフロー、テンソルフロー)とは、Googleが開発した、機械学習に用いるためのソフトウェアライブラリです。
現在、TensorFlowの公式ドキュメントによると、利用するモデル・ファイル「〇〇.pb」は、アプリにセットアップする際「〇〇.tflite」で設定することを推奨中です。
それに伴ってTensorFlowで作成される「〇〇.pb」を「〇〇.tflite」に変換する必要があります。
さて今回は、実際に「〇〇.tflite」に変換してみて、その過程で上手くいかなかった点などの情報をまとめたいと思います。
【TensorFlowの公式ドキュメント】
https://www.tensorflow.org/
「.pb」と「.tflite」
変換する前に、「.pb」と「.tflite」について少し復習しておきたいと思います。
.pbとは
.pbは、「protobuf」の略で、さらにprotobufも「Protocol Buffers」という用語の略という2段階略語です。
.pbの中身としては、TensorFlowによるラーニングのグラフ定義とモデルの重みを収納します。
機械学習時の画像の特徴や、音声の特徴などを記録し、入力値に対して予測値を返すことが可能です。
.tfliteとは
それに対して.tflite は、「Android」や「iOS」、「ラズベリーパイ」での利用を目的としたモデル・ファイルの形式で、予測結果を素早く算出できるという特徴があります。
現在、GitHubやWeb上でサンプルのAIアプリはたくさん公開されていますが、モデルには『〇〇.pb』や『◇◇.h5』『△△.lite』など、いくつかの拡張子でデータ・セットされています。
TensorFlow開発側は、『〇〇.pb』などではなく『☆☆.tflite』の利用を推奨中。
そのため『.tflite』を学習し、使う必要がでてきています。
.pbから.tfliteに変換する方法
コンバート方法 | 内容 | サンプル |
コマンド | tflite_convert | Codelabs |
コマンド | TOCO | GitHub |
PythonAPI | クラス: tf.lite.TFLiteConverter | TensorFlow公式 |
TensorFlowによって生成されるモデル・ファイルの「.pb」は、大きく分けて上記のようなパターンで「.tflite」に変換可能です。
公式ドキュメントでは、一番下の「クラス:tf.lite.TFLiteConverter」の使用を推奨しています。
3種類のクラス
そして「クラス:tf.lite.TFLiteConverter」では、モデルに応じて3種類のクラスが用意されています。
- TFLiteConverter.from_saved_model(): SavedModel ディレクトリ を変換、MobileNetなどで。
- TFLiteConverter.from_keras_model(): tf.kerasモデルを変換、MNISTなど。
- TFLiteConverter.from_concrete_functions(): 具象関数を変換。
そして各クラスの記述方法は違って、予め定められた方法で記述する必要があります。
例えば一番上の from_saved_model() クラスを用いる場合は、モデルを一度 tf.saved_model.save() で読み込んでおく必要があります。
この工程は、他のfrom_keras_model() や from_concrete_functions() では必要ないため、忘れがちになります。
古いクラスのご紹介
- tf.contrib.lite.TocoConverter.from_keras_model_file()
- tf.contrib.lite.TFLiteConverter.from_keras_model_file()
- tf.lite.TFLiteConverter.from_keras_model_file()
上記は、サンプルコードなどで登場する書き方になります。
古い形の変換方法になりますので、TensorFlowのバージョン等留意する必要があります。
.pbを.tfliteに変換
公式ドキュメントで紹介されているコードを元に、 「SavedModel」と「tf.keras」の2つを使ってみましょう。
【今回ご紹介するプログラムの内容を収めた Google Colab】
https://colab.research.google.com/drive/135W5wfFkVSbNjpCMftdVoVIwwa4VCs9k
サンプルコードを実行
エラーが発生
公式ドキュメントのサンプルコードを実行しただけでは、実はエラーという結果になります。
ValueError: Attempted to save a function b'__inference_<lambda>_45' which references a symbolic Tensor Tensor("Variable/read:0", shape=(), dtype=float32) that is not a simple constant. This is not supported.
理由は、書かれているコードは、TensorFlow のバージョン2系なのに対して、サンプルコード上の import tensorflow as tf では 1.15 がインストールされるためです。
サンプル通りの import文ではなく、TensorFlow2系をインストールする必要があります。
print("TF version:", tf.__version__) でバージョンを確認してみて下さい。
TensorFlow2系のインストール
TensorFlow2系のインストールは、try文でバージョンの有無をチェックし、インストールする TensorFlow のバージョンを指定します。
上記コードを実行すると、TensorFlow2系をインストールされます。
SavedModel形式で.pbを.tfliteへ変換
公式ドキュメントのコードをベースに、.pb ファイルの出力先の確認と .tflite のファイル作成を追加。
tf.saved_model.save() 以降が変換処理で、書き方が参考になると思います。
モデルを tf.saved_model.save() に入れてから tf.lite.TFLiteConverter.from_saved_model() していますね。
/tmp/test_saved_model 内に保存された saved_model.pb、そして open() 関数で書き出される.tflite、実存するファイルを確認することで .pb や .tflite の存在を確認できますね。
MobileNet(転移学習)を使用するケースで、こちらのSavedModel形式はよく見かけます。
tf.keras方式でモデルを.tfliteへ変換
サンプルコードを実行しただけでは、ラーニング後のモデルファイルを確認できません。
model.save('model.h5') を実行すると tf.keras方式のモデルを出力することが可能です。
こちらの場合も最終的にopen() 関数を使用することで、 .tflite を保存可能になります。
さいごに
tfliteの作成方法は開発段階のところもあったりで、やや不安定な印象も受けます。
また、GitHubなどで公開されているプロジェクトでは、「.pb」のままアプリにセットしているケースもあります。
情報が散策し扱いにくい印象を受ける「tflite:ですが、公式ドキュメントをベースにスマートに使いこなしたいですね!
こちらの記事もオススメ!
2020.07.28機械学習 特集知識編人工知能・機械学習でよく使われるワード徹底まとめ!機械学習の元祖「パーセプトロン」とは?【人工知能】ニューラルネ...
2020.07.17ライトコード的「やってみた!」シリーズ「やってみた!」を集めました!(株)ライトコードが今まで作ってきた「やってみた!」記事を集めてみました!※作成日が新し...
ライトコードでは、エンジニアを積極採用中!
ライトコードでは、エンジニアを積極採用しています!社長と一杯しながらお話しする機会もご用意しております。そのほかカジュアル面談等もございますので、くわしくは採用情報をご確認ください。
採用情報へ
「好きを仕事にするエンジニア集団」の(株)ライトコードです! ライトコードは、福岡、東京、大阪、名古屋の4拠点で事業展開するIT企業です。 現在は、国内を代表する大手IT企業を取引先にもち、ITシステムの受託事業が中心。 いずれも直取引で、月間PV数1億を超えるWebサービスのシステム開発・運営、インフラの構築・運用に携わっています。 システム開発依頼・お見積もり大歓迎! また、現在「WEBエンジニア」「モバイルエンジニア」「営業」「WEBデザイナー」を積極採用中です! インターンや新卒採用も行っております。 以下よりご応募をお待ちしております! https://rightcode.co.jp/recruit