TensorFlow 演算の融合

概要

このページでは、TensorFlow の複合オペレーションを LiteRT の融合オペレーションに変換するために必要な設計と手順について説明します。このインフラストラクチャは汎用であり、TensorFlow の複合オペレーションを LiteRT の対応する融合オペレーションに変換することをサポートしています。

このインフラストラクチャの使用例として、TensorFlow RNN オペレーションの LiteRT への融合があります。詳しくは、こちらをご覧ください。

融合オペレーションとは

図形描画

TensorFlow オペレーションは、プリミティブ オペレーション(tf.add など)である場合と、他のプリミティブ オペレーション(tf.einsum など)から構成される場合があります。プリミティブ オペレーションは TensorFlow グラフ内の単一のノードとして表示されますが、複合オペレーションは TensorFlow グラフ内のノードのコレクションです。複合オペレーションの実行は、構成する各プリミティブ オペレーションの実行と同じです。

融合演算は、対応する複合演算内の各プリミティブ演算によって実行されるすべての計算を包含する単一の演算に対応します。

融合演算のメリット

融合オペレーションは、全体的な計算を最適化し、メモリ フットプリントを削減することで、基盤となるカーネル実装のパフォーマンスを最大化するために存在します。これは、特に低レイテンシの推論ワークロードやリソースが制約されたモバイル プラットフォームで非常に有用です。

融合演算は、量子化などの複雑な変換を定義するための高レベルのインターフェースも提供します。これがないと、より細かいレベルで実行することが不可能になったり、非常に難しくなったりします。

LiteRT には、上記の理由から、融合オペレーションのインスタンスが多数あります。これらの融合オペレーションは通常、ソース TensorFlow プログラムの複合オペレーションに対応します。LiteRT で単一の融合演算として実装される TensorFlow の複合演算の例としては、単方向および双方向シーケンス LSTM などのさまざまな RNN 演算、畳み込み(conv2d、バイアス追加、relu)、全結合(matmul、バイアス追加、relu)などがあります。LiteRT では、LSTM 量子化は現在、融合 LSTM オペレーションでのみ実装されています。

融合演算の課題

TensorFlow の複合オペレーションを LiteRT の融合オペレーションに変換することは難しい問題です。理由は次のとおりです。

  1. 複合オペレーションは、明確な境界のない一連のプリミティブ オペレーションとして TensorFlow グラフで表されます。このような複合オペレーションに対応するサブグラフを(パターン マッチングなどで)特定するのは非常に困難です。

  2. 融合された LiteRT オペレーションをターゲットとする TensorFlow 実装が複数存在する可能性があります。たとえば、TensorFlow には多くの LSTM 実装(Keras、Babelfish/lingvo など)があり、それぞれが異なるプリミティブ オペレーションで構成されていますが、それらはすべて LiteRT の同じ融合 LSTM オペレーションに変換できます。

そのため、融合オペレーションの変換は非常に困難であることがわかっています。

複合オペレーションを tf.function でラップする

多くの場合、モデルの一部を TFLite の単一のオペレーションにマッピングできます。これは、特定のオペレーション用に最適化された実装を記述する際のパフォーマンスに役立ちます。TFLite で融合オペレーションを作成するには、融合オペレーションを表すグラフの部分を特定し、tf.function でラップします。このとき、属性値 tfl_fusable_op を持つ属性 tf.functiontrue の値を持つ「experimental_implements」属性を指定します。カスタム オペレーションが属性を受け取る場合は、同じ experimental_implements の一部として渡します。

例:

def get_implements_signature():
  implements_signature = [
    # 'name' will be used as a name for the operation.
    'name: "my_custom_fused_op"',
    # attr "tfl_fusable_op" is required to be set with true value.
    'attr {key: "tfl_fusable_op" value { b: true } }',
    # Example attribute "example_option" that the op accepts.
    'attr {key: "example_option" value { i: %d } }' % 10
  ]
  return ' '.join(implements_signature)

@tf.function(experimental_implements=get_implements_signature())
def my_custom_fused_op(input_1, input_2):
  # An empty function that represents pre/post processing example that
  # is not represented as part of the Tensorflow graph.
  output_1 = tf.constant(0.0, dtype=tf.float32, name='first_output')
  output_2 = tf.constant(0.0, dtype=tf.float32, name='second_output')
  return output_1, output_2

class TestModel(tf.Module):
  def __init__(self):
    super(TestModel, self).__init__()
    self.conv_1 = tf.keras.layers.Conv2D(filters=1, kernel_size=(3, 3))
    self.conv_2 = tf.keras.layers.Conv2D(filters=1, kernel_size=(3, 3))

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[1, 28, 28, 3], dtype=tf.float32),
      tf.TensorSpec(shape=[1, 28, 28, 3], dtype=tf.float32),
  ])
  def simple_eval(self, input_a, input_b):
    return my_custom_fused_op(self.conv_1(input_a), self.conv_2(input_b))

tfl_fusable_op 属性はすでにこれを意味しているため、コンバータで allow_custom_ops を設定する必要はありません。

カスタム演算を実装し、TFLite インタープリタに登録する

融合演算を TFLite カスタム演算として実装します。手順をご覧ください。

op を登録する名前は、実装シグネチャの name 属性で指定された名前と類似している必要があります。

例の op の例は

  TfLiteRegistration reg = {};
  // This name must match the name specified in the implements signature.
  static constexpr char kOpName[] = "my_custom_fused_op";
  reg.custom_name = kOpName;
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
    // Add your code.
    return kTfLiteOk;
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
    // Add your code.
    return kTfLiteOk;
  };
  reg.builtin_code = kTfLiteCustom;
  resolver->AddCustom(kOpName, &reg);

複合オペレーションから融合オペレーションへの変換(Advanced)

TensorFlow 複合オペレーションを LiteRT 融合オペレーションに変換する全体的なアーキテクチャは次のとおりです。

図形描画

複合オペレーションを tf.function でラップする

TensorFlow モデルのソースコードで、複合オペレーションを特定して抽象化し、experimental_implements 関数アノテーションを使用して tf.function にします。エンベディング検索の例をご覧ください。この関数はインターフェースを定義します。その引数は、変換ロジックの実装に使用する必要があります。

変換コードを記述する

変換コードは、implements アノテーション付きの関数のインターフェースに従って記述されます。エンベディング ルックアップの融合の例をご覧ください。概念的には、変換コードは、このインターフェースの複合実装を融合実装に置き換えます。

prepare-composite-functions パスで、変換コードをプラグインします。

より高度な使用法では、融合演算のオペランドを導出するために、複合演算のオペランドの複雑な変換を実装できます。例として、Keras LSTM 変換コードをご覧ください。

LiteRT に変換する

TFLiteConverter.from_saved_model API を使用して LiteRT に変換します。

仕組み

ここでは、LiteRT で融合オペレーションに変換する際の全体的な設計の概要について説明します。

TensorFlow でのオペレーションの構成

experimental_implements 関数属性で tf.function を使用すると、ユーザーは TensorFlow プリミティブ オペレーションを使用して新しいオペレーションを明示的に構成し、結果の複合オペレーションが実装するインターフェースを指定できます。これは、次のような情報を提供するため、非常に便利です。

  1. 基盤となる TensorFlow グラフの複合オペレーションの明確に定義された境界。
  2. このオペレーションが実装するインターフェースを明示的に指定します。tf.function の引数は、このインターフェースの引数に対応します。

たとえば、エンベディング検索を実装するために定義された複合オペレーションについて考えてみましょう。これは、LiteRT の融合オペレーションにマッピングされます。

  @tf.function(
        experimental_implements="embedding_lookup")
    def EmbFprop(embs, ids_vec):
      """Embedding forward prop.

      Effectively, it computes:
        num = size of ids_vec
        rets = zeros([num, embedding dim])
        for i in range(num):
          rets[i, :] = embs[ids_vec[i], :]
        return rets

      Args:
        embs: The embedding matrix.
        ids_vec: A vector of int32 embedding ids.

      Returns:
        The result of embedding lookups. A matrix of shape
        [num ids in ids_vec, embedding dims].
      """
      num = tf.shape(ids_vec)[0]
      rets = inplace_ops.empty([num] + emb_shape_suf, py_utils.FPropDtype(p))

      def EmbFpropLoop(i, embs, ids_vec, rets):
        # row_id = ids_vec[i]
        row_id = tf.gather(ids_vec, i)
        # row = embs[row_id]
        row = tf.reshape(tf.gather(embs, row_id), [1] + emb_shape_suf)
        # rets[i] = row
        rets = inplace_ops.alias_inplace_update(rets, [i], row)
        return embs, ids_vec, rets

      _, _, rets = functional_ops.For(
          start=0,
          limit=num,
          delta=1,
          inputs=[embs, ids_vec, rets],
          body=EmbFpropLoop,
          rewrite_with_while=compiled)
      if len(weight_shape) > 2:
        rets = tf.reshape(rets, [num, symbolic.ToStatic(p.embedding_dim)])
      return rets

上記のように、tf.function を介してモデルで複合オペレーションを使用することで、このようなオペレーションを融合された LiteRT オペレーションに識別して変換するための一般的なインフラストラクチャを構築できます。

LiteRT コンバータの拡張

今年の初めにリリースされた LiteRT コンバータは、すべての変数が対応する定数値に置き換えられたグラフとして TensorFlow モデルをインポートすることのみをサポートしていました。このようなグラフでは、変数を定数に変換できるようにすべての関数がインライン化されているため、オペレーションの融合では機能しません。

変換プロセスで experimental_implements 機能とともに tf.function を活用するには、変換プロセスの後半まで関数を保持する必要があります。

そのため、複合オペレーションの融合ユースケースをサポートするために、コンバータで TensorFlow モデルをインポートして変換する新しいワークフローを実装しました。具体的には、以下の新機能が追加されます。

  1. TensorFlow 保存済みモデルを MLIR にインポートする
  2. 複合オペレーションを融合する
  3. 変数の可変性分析
  4. すべての読み取り専用変数をフリーズする

これにより、関数インライン化と変数フリーズの前に、複合オペレーションを表す関数を使用してオペレーション フュージョンを実行できます。

オペレーションの融合の実装

オペレーション融合パスについて詳しく見ていきましょう。このパスは次の処理を行います。

  1. MLIR モジュールのすべての関数をループ処理します。
  2. 関数に tf._implements 属性がある場合、属性値に基づいて、適切なオペレーション融合ユーティリティを呼び出します。
  3. オペレーション融合ユーティリティは、関数のオペランドと属性(変換のインターフェースとして機能)に対して動作し、関数の本体を融合オペレーションを含む同等の関数本体に置き換えます。
  4. 多くの場合、置き換えられた本体には、融合されたオペレーション以外のオペレーションが含まれます。これらは、融合オペレーションのオペランドを取得するために、関数のオペランドに対する静的変換に対応します。これらの計算はすべて定数フォールディングで削除できるため、エクスポートされた flatbuffer には存在しません。この flatbuffer には、融合されたオペレーションのみが存在します。

パスのメイン ワークフローを示すコード スニペットは次のとおりです。

void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
                                                        StringAttr attr) {
  if (attr.getValue() == "embedding_lookup") {
    func.eraseBody();
    func.addEntryBlock();
    // Convert the composite embedding_lookup function body to a
    // TFLite fused embedding_lookup op.
    ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
    if (failed(convert_embedded_lookup.VerifySignature())) {
      return signalPassFailure();
    }
    convert_embedded_lookup.RewriteFunc();
  } else if (attr.getValue() == mlir::TFL::kKerasLstm) {
     func.eraseBody();
     func.addEntryBlock();
     OpBuilder builder(func.getBody());
     if (failed(ConvertKerasLSTMLayer(func, &builder))) {
       return signalPassFailure();
     }
  } else if (.....) /* Other fusions can plug in here */
}

次のコード スニペットは、関数を変換インターフェースとして活用して、この複合オペレーションを LiteRT の融合オペレーションにマッピングする方法を示しています。

void RewriteFunc() {
    Value lookup = func_.getArgument(1);
    Value value = func_.getArgument(0);
    auto output_type = func_.getType().getResult(0);

    OpBuilder builder(func_.getBody());
    auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
        func_.getLoc(), output_type, lookup, value);

    builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
  }