概要
LiteRT は TensorFlow RNN モデルを LiteRT のモデルに変換できます。 統合 LSTM 演算です。融合オペレーションは、VM のパフォーマンスを カーネルの実装をサポートするとともに、より高レベルの インターフェースを使用して、量子化などの複雑な変換を定義します。
TensorFlow には RNN API の多くのバリエーションがあるため、Google のアプローチは 2 つ:
- Keras LSTM などの標準の TensorFlow RNN API のネイティブ サポートを提供する。 これが推奨のオプションです。
- 以下のコンバージョン インフラストラクチャへのインターフェース を提供する 差し込んで変換するユーザー定義 RNN 実装 LiteRT。すぐに利用できる例をいくつか コンバージョンを LSTMCellSimple および LayerNormalizedLSTMCellSimple RNN インターフェース。
コンバータ API
この機能は TensorFlow 2.3 リリースの一部です。また、 tf-nightly pip、または head から実行します。
この変換機能は、LiteRT に変換するときに使用できます。 Keras モデルから直接呼び出すこともできます。使用例をご覧ください。
保存済みモデルから
# build a saved model. Here concrete_function is the exported function
# corresponding to the TensorFlow model containing one or more
# Keras LSTM layers.
saved_model, saved_model_dir = build_saved_model_lstm(...)
saved_model.save(saved_model_dir, save_format="tf", signatures=concrete_func)
# Convert the model.
converter = TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
Keras モデルから
# build a Keras model
keras_model = build_keras_lstm(...)
# Convert the model.
converter = TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
例
Keras LSTM から LiteRT Colab LiteRT インタープリタでのエンドツーエンドの使用方法を示します。
サポートされている TensorFlow RNN API
Keras LSTM 変換(推奨)
Keras LSTM から LiteRT へのすぐに使える変換がサポートされています。対象 詳しくは、こちらの Keras LSTM インターフェース と変換ロジックに反映されます。 こちらをご覧ください。
LiteRT の LSTM 契約を強調することも重要です。 Keras オペレーション定義に追加します。
- 入力テンソルのディメンション 0 はバッチサイズです。
- recurrent_weight テンソルの次元 0 は 出力です。
- weight テンソルと recurrent_kernel テンソルは転置されます。
- 転置された重み、転置された recurrent_kernel、bias テンソルは、 次元 0 に沿って 4 つの等しいサイズのテンソルに分割されます。これらは 入力ゲート、忘れゲート、セル、出力ゲート。
Keras LSTM バリアント
メジャー タイム
ユーザーは時間優先か、時間優先なしかを選択できます。Keras LSTM では、従来の 属性を宣言します。単方向シーケンス LSTM の場合、 簡単に unidirecional_sequence_lstm の タイムメジャー属性。
双方向 LSTM
双方向 LSTM は、Keras LSTM レイヤの 2 つを使用して実装できます。 1 つは前方、もう 1 つは後方です。例は こちらをご覧ください。 go_backward 属性があると、後方 LSTM と認識され、 グループ化して組み合わせて使用することもできます。これは今後の取り組みです。現在、 これにより、LiteRT に 2 つの UnidirectionalSequenceLSTM オペレーションが作成されます。 モデルです。
ユーザー定義の LSTM 変換の例
LiteRT では、ユーザー定義の LSTM を変換することもできます。 あります。ここでは、Lingvo の LSTM を使用して、 確認します。詳しくは、 lingvo.LSTMCellSimple インターフェース 変換ロジックは こちらをご覧ください。 また、 lingvo.LayerNormalizedLSTMCellSimple インターフェース その変換ロジックは こちらをご覧ください。
LiteRT への「独自の TensorFlow RNN の使用」
ユーザーの RNN インターフェースが標準でサポートされているものと異なる場合、 2 つのオプションがあります
オプション 1: TensorFlow Python でアダプタコードを作成して RNN インターフェースを適応させる Keras RNN インターフェースに渡します。つまり、tf.function には、 tf_ implementss アノテーション 生成された RNN インターフェースの関数は、モデルによって 実装されています。その後は Keras LSTM で使用していたものと同じコンバージョン API が 機能します。
オプション 2: 上記を行うことができない場合(例: Keras LSTM に一部が不足している場合) LiteRT の融合 LSTM 演算によって現在公開されている レイヤの正規化)を作成し、LiteRT コンバータを拡張します。 変換し、これを prepare-composite-Functions 関数に MLIR-pass こちらをご覧ください。 関数のインターフェースは API コントラクトのように扱う必要があります。 融合された LiteRT LSTM に変換するために必要な引数が含まれています。 入力、バイアス、重み、投影、レイヤ正規化などの演算が含まれます。 この関数の引数として渡されるテンソルが既知の (MLIR では RankedTensorType)。これにより、 これらのテンソルを RankedTensorType と仮定し、 融合された LiteRT に対応するランク付けされたテンソルに変換し、 オペランドを渡して
こうしたコンバージョン フローの完全な例が、Lingvo の LSTMCellSimple です。 LiteRT 変換。
Lingvo の LSTMCellSimple は こちらをご覧ください。 この LSTM セルでトレーニングされたモデルは、次のように LiteRT に変換できます。 次のようになります。
- LSTMCellSimple の使用をすべて tf.function でラップし、tf_implementations を使用 そのようにラベル付けされたアノテーション(たとえば、lingvo.LSTMCellSimple は 適切なアノテーション名をここに入力してください)。生成された tf.function が、 変換コードに想定されている関数のインターフェースと一致します。この モデル作成者がアノテーションを追加するコントラクトと、 なります。
prepare-composite-functions パスを拡張してカスタム複合オペレーションをプラグインする LiteRT Fused LSTM op 変換です。詳しくは、 LSTMCellSimple なります。
変換契約:
重みテンソルと射影テンソルは転置されます。
{input, recurrent}から{セル、インプット ゲート、フォーゲット ゲート、出力 gate} は、転置された重みテンソルをスライスして抽出します。
{bias} に対する {bias} は、 バイアステンソルをスライスして抽出します
射影は、転置された射影テンソルをスライスすることによって抽出されます。
同様のコンバージョンが作成されています LayerNormalizedLSTMCellSimple.
次のものを含む、LiteRT 変換インフラストラクチャの残りの部分は、 MLIR パス LiteRT フラットバッファへの最終的なエクスポートは、 されます。
既知の問題/制限事項
- 現時点では、ステートレス Keras LSTM(デフォルト 動作)。ステートフルな Keras LSTM 変換は将来的な機能です。
- ただし、 基盤となるステートレス Keras LSTM レイヤを作成し、Google Cloud で 必要があります。このような TensorFlow プログラムは、 ここで説明する機能を使用する LiteRT。
- Bidirectional LSTM は現在、2 つの UnidirectionalSequenceLSTM としてモデル化されています。 LiteRT での操作を行えます。これは Chronicle SOAR サーバーに BidirectionalSequenceLSTM 演算