Google Tensor SDK でモデルをコンパイルする

Google Tensor SDK を使用すると、Google Tensor の専用 TPU で機械学習(ML)モデルを直接最適化して実行できます。標準の ML モデルに加えて、LiteRT ワークフロー内で言語モデルをコンパイルして高速化できます。

事前最適化された特定のモデルについては、ワークフローに含めることができるコンパイラ オプションのバイナリ ファイルも提供しています。Google Tensor で最適なパフォーマンスを実現するには、コンパイル時に次の最適化フラグを使用することをおすすめします。

Google Tensor のコンパイル フラグ

特定の構成フラグを適用して、パフォーマンスとリソース使用量を調整することで、モデル コンパイル プロセスを最適化します。これらのパラメータは、LiteRT Python ワークフロー内で使用して、PyTorch モデルと TFLite モデルの両方のコンパイル動作を調整できます。

フラグ 要件 説明 デフォルト値
google_tensor_truncation_type 省略可 浮動小数点演算のターゲット データ型を設定します。
サポートされている値: auto(デフォルト)、bfloat16、half、no_truncation
自動
google_tensor_sharding_intensity 省略可 並列処理のためにモデルをどの程度積極的に分割するかを制御します。
オプション: minimal、moderate、extensive、maximum。
ミニマル
google_tensor_int64_to_int32 省略可 コンパイラが 64 ビット整数を 32 ビット整数に変換できるようにするには、true に設定します。これは一部のモデルで必要になることがあります。 誤り
google_tensor_enable_large_model_support 省略可 モデルが 2 GB を超える場合は true に設定します。 誤り
google_tensor_enable_4bit_compilation 省略可 4 ビット パラメータを含む畳み込み演算でモデルのコンパイルを有効にします。 誤り
google_tensor_extra_options_path 省略可 追加のコンパイラ オプションをバイナリ ファイルとして渡します。 ""(空)

次のコード スニペットに示すように、これらのフラグを LiteRT Python フローで使用できます。

  • ai_edge_torch を使用して pytorch モデルをコンパイルする場合

    compiled_models = (
      ai_edge_torch.experimental_add_compilation_backend(tensor_g5_target,
      flag_name1=value1,
      flag_name2=value2, ...).convert(
        channel_last_selfie_segmentation,
        sample_input))
    
  • tflite モデルをコンパイルする場合

    compiled_models = aot_lib.aot_compile(
        tflite_model_path,
        target=[tensor_g5_target],
        flag_name1=value1,
        flag_name2=value2, ...)
    

使用例

次の例では、google_tensor_truncation_type="half" フラグを使用しています。

  compiled_models = aot_lib.aot_compile(
      tflite_model_path,
      target=[tensor_g5_target],
      keep_going=False,
      google_tensor_truncation_type="half"
  )

詳しくは、LiteRT AOT Colab をご覧ください。

Google Tensor 用に言語モデルをコンパイルする

Google Tensor 用の言語モデルをコンパイルするには、NPU AOT コンパイルの手順に沿って操作します。

Google Tensor TPU 用に LLM をエクスポートするには、NPU コンパイルに必要な追加フラグの例に沿って操作します。

例:

litert-torch export-hf \
  --model=google/gemma-3-270m-it \
  --output_dir=/tmp/gemma3-270m-google-tensor-g5 \
  --split_cache \
  --externalize_embedder \
  --prefill_lengths=128, \
  --cache_length=1280 \
  --quantization_recipe="weight_only_wi8_afp32" \
  --aot_backend=GOOGLE \
  --aot_soc_model=Tensor_G5 \
  --aot_compilation_config_dict='{"google_tensor_enable_large_model_support": True}'