使用 Google Tensor SDK 编译模型

借助 Google Tensor SDK,您可以直接在 Google Tensor 的专用 TPU 上优化和运行机器学习 (ML) 模型。除了标准机器学习模型之外,您还可以在 LiteRT 工作流中编译和加速语言模型。

对于某些预优化模型,我们还提供了一个额外的编译器选项二进制文件,您可以将其纳入工作流程中。为了在 Google Tensor 上实现最佳性能,我们建议在编译期间使用以下优化标志:

Google Tensor 的编译标志

应用特定的配置标志来定制性能和资源使用情况,从而优化模型编译流程。 您可以在 LiteRT Python 工作流中使用这些参数来调整 PyTorch 和 TFLite 模型的编译行为。

标志 要求 说明 默认值
google_tensor_truncation_type 可选 为浮点运算设置目标数据类型。
支持的值:auto(默认值)、bfloat16、half、no_truncation
自动
google_tensor_sharding_intensity 可选 控制模型拆分以进行并行处理的积极程度。
选项:minimal、moderate、extensive、maximum。
极简
google_tensor_int64_to_int32 可选 设置为 true 可让编译器将 64 位整数转换为 32 位整数,这对于某些模型可能是必需的。 错误
google_tensor_enable_large_model_support 可选 如果模型大于 2 GB,则设置为 true。 错误
google_tensor_enable_4bit_compilation 可选 支持编译包含 4 位参数的卷积运算模型。 错误
google_tensor_extra_options_path 可选 以二进制文件形式传递额外的编译器选项。 ""(空)

您可以将这些标志与 LiteRT Python 流搭配使用,如以下代码段所示:

  • 使用 ai_edge_torch 编译 PyTorch 模型时

    compiled_models = (
      ai_edge_torch.experimental_add_compilation_backend(tensor_g5_target,
      flag_name1=value1,
      flag_name2=value2, ...).convert(
        channel_last_selfie_segmentation,
        sample_input))
    
  • 编译 tflite 模型时

    compiled_models = aot_lib.aot_compile(
        tflite_model_path,
        target=[tensor_g5_target],
        flag_name1=value1,
        flag_name2=value2, ...)
    

用法示例

在以下示例中,使用了 google_tensor_truncation_type="half" 标志:

  compiled_models = aot_lib.aot_compile(
      tflite_model_path,
      target=[tensor_g5_target],
      keep_going=False,
      google_tensor_truncation_type="half"
  )

如需了解详情,请参阅 LiteRT AOT Colab

为 Google Tensor 编译语言模型

如需为 Google Tensor 编译语言模型,请按照 NPU AOT 编译中的说明操作。

如需导出 LLM 以用于 Google Tensor TPU,请按照示例操作,了解 NPU 编译所需的其他标志。

示例

litert-torch export-hf \
  --model=google/gemma-3-270m-it \
  --output_dir=/tmp/gemma3-270m-google-tensor-g5 \
  --split_cache \
  --externalize_embedder \
  --prefill_lengths=128, \
  --cache_length=1280 \
  --quantization_recipe="weight_only_wi8_afp32" \
  --aot_backend=GOOGLE \
  --aot_soc_model=Tensor_G5 \
  --aot_compilation_config_dict='{"google_tensor_enable_large_model_support": True}'