TensorFlow RNN から LiteRT に変換

概要

LiteRT は、TensorFlow RNN モデルを LiteRT の融合 LSTM オペレーションに変換することをサポートしています。融合演算は、基盤となるカーネル実装のパフォーマンスを最大化し、量子化などの複雑な変換を定義するための上位レベルのインターフェースを提供するために存在します。

TensorFlow には RNN API のバリエーションが多数あるため、Google のアプローチは次の 2 つに分かれています。

  1. Keras LSTM などの標準 TensorFlow RNN API のネイティブ サポートを提供します。これが推奨のオプションです。
  2. ユーザー定義の RNN 実装を接続して LiteRT に変換するための 変換インフラストラクチャへのインターフェースを提供します。lingvo の LSTMCellSimpleLayerNormalizedLSTMCellSimple RNN インターフェースを使用して、このような変換を行う例をいくつか用意しています。

Converter API

この機能は TensorFlow 2.3 リリースの一部です。tf-nightly pip または head からも入手できます。

この変換機能は、SavedModel を介して LiteRT に変換する場合、または Keras モデルから直接変換する場合に使用できます。使用例をご覧ください。

保存済みモデルから

# build a saved model. Here concrete_function is the exported function
# corresponding to the TensorFlow model containing one or more
# Keras LSTM layers.
saved_model, saved_model_dir = build_saved_model_lstm(...)
saved_model.save(saved_model_dir, save_format="tf", signatures=concrete_func)

# Convert the model.
converter = TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

Keras モデルから

# build a Keras model
keras_model = build_keras_lstm(...)

# Convert the model.
converter = TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()

Keras LSTM から LiteRT への Colab では、LiteRT インタープリタを使用したエンドツーエンドの使用方法を示しています。

TensorFlow RNN API のサポート

Keras LSTM から LiteRT への変換をすぐにサポートします。この仕組みの詳細については、Keras LSTM インターフェース と、こちらの変換ロジックをご覧ください。

また、Keras オペレーション定義に関して LiteRT の LSTM コントラクトをハイライトすることも重要です。

  1. 入力テンソルのディメンション 0 はバッチサイズです。
  2. recurrent_weight テンソルのディメンション 0 は出力の数です。
  3. weight テンソルと recurrent_kernel テンソルが転置されます。
  4. 転置された重み、転置された recurrent_kernel、バイアス テンソルは、ディメンション 0 に沿って 4 つの等しいサイズのテンソルに分割されます。これらは、入力ゲート、忘却ゲート、セル、出力ゲートに対応します。

Keras LSTM バリアント

Time major

ユーザーは、時間優先または時間優先なしを選択できます。Keras LSTM は、関数定義属性に時間優先属性を追加します。単方向シーケンス LSTM の場合、unidirecional_sequence_lstm の時間メジャー属性にマッピングするだけです。

双方向 LSTM

双方向 LSTM は、順方向と逆方向の 2 つの Keras LSTM レイヤで実装できます。例については、こちらをご覧ください。go_backward 属性が検出されると、逆方向 LSTM として認識され、順方向 LSTM と逆方向 LSTM がグループ化されます。これは今後の作業です。現在、これにより LiteRT モデルに 2 つの UnidirectionalSequenceLSTM オペレーションが作成されます。

ユーザー定義の LSTM 変換の例

LiteRT には、ユーザー定義の LSTM 実装を変換する方法も用意されています。ここでは、Lingvo の LSTM を実装例として使用します。詳細については、lingvo.LSTMCellSimple インターフェースと、こちらの変換ロジックをご覧ください。また、lingvo.LayerNormalizedLSTMCellSimple インターフェースと、その変換ロジックはこちらで、Lingvo の別の LSTM 定義の例も提供しています。

LiteRT への「独自の TensorFlow RNN の持ち込み」

ユーザーの RNN インターフェースが標準でサポートされているものと異なる場合は、次の 2 つのオプションがあります。

オプション 1: TensorFlow Python でアダプタ コードを記述して、RNN インターフェースを Keras RNN インターフェースに適応させます。つまり、生成された RNN インターフェースの関数に tf_implements アノテーションが付いた tf.function があり、これは Keras LSTM レイヤによって生成されたものと同じです。その後は、Keras LSTM に使用したのと同じ変換 API が機能します。

オプション 2: 上記が不可能な場合(Keras LSTM に、LiteRT の融合 LSTM op(レイヤ正規化など)で現在公開されている機能の一部がない場合など)は、カスタム変換コードを記述して LiteRT コンバータを拡張し、こちらの prepare-composite-functions MLIR パスにプラグインします。関数のインターフェースは API 契約として扱われるべきであり、融合された LiteRT LSTM オペレーションに変換するために必要な引数(入力、バイアス、重み、射影、レイヤ正規化など)を含むべきです。この関数に引数として渡されるテンソルは、既知のランク(MLIR の RankedTensorType など)を持つことが望ましいです。これにより、これらのテンソルを RankedTensorType として想定できる変換コードを簡単に記述できるようになり、融合された LiteRT オペレーターのオペランドに対応するランク付きテンソルに変換できます。

このような変換フローの完全な例は、Lingvo の LSTMCellSimple から LiteRT への変換です。

Lingvo の LSTMCellSimple はこちらで定義されています。この LSTM セルでトレーニングされたモデルは、次のように LiteRT に変換できます。

  1. LSTMCellSimple のすべての使用を、そのようにラベル付けされた tf_implements アノテーションを含む tf.function でラップします(たとえば、lingvo.LSTMCellSimple は適切なアノテーション名です)。生成された tf.function が、変換コードで想定される関数のインターフェースと一致していることを確認します。これは、アノテーションとコンバージョン コードを追加するモデル作成者間の契約です。
  2. prepare-composite-functions パスを拡張して、カスタム複合演算を LiteRT 融合 LSTM 演算変換にプラグインします。LSTMCellSimple 変換コードをご覧ください。

    変換契約:

  3. 重みテンソルと射影テンソルが転置されます。

  4. {input, recurrent} から {cell, input gate, forget gate, output gate} は、転置された重みテンソルをスライスすることで抽出されます。

  5. {bias} から {cell, input gate, forget gate, output gate} は、バイアス テンソルをスライスして抽出されます。

  6. 投影は、転置された投影テンソルをスライスして抽出されます。

  7. 同様の変換が LayerNormalizedLSTMCellSimple に書き込まれます。

  8. LiteRT 変換インフラストラクチャの残りの部分(定義されているすべての MLIR パスや LiteRT flatbuffer への最終エクスポートなど)は再利用できます。

既知の問題と制限事項

  1. 現在、サポートされているのはステートレス Keras LSTM(Keras のデフォルトの動作)の変換のみです。ステートフル Keras LSTM の変換は今後の作業です。
  2. 基盤となるステートレス Keras LSTM レイヤを使用してステートフル Keras LSTM レイヤをモデル化し、ユーザー プログラムで状態を明示的に管理することは可能です。このような TensorFlow プログラムは、ここで説明する機能を使用して LiteRT に変換できます。
  3. 双方向 LSTM は現在、LiteRT で 2 つの UnidirectionalSequenceLSTM オペレーションとしてモデル化されています。これは単一の BidirectionalSequenceLSTM オペレーションに置き換えられます。