Compile models with Google Tensor SDK

The Google Tensor SDK lets you optimize and run machine learning (ML) models directly on Google Tensor's dedicated TPU. In addition to standard ML models, you can compile and accelerate language models within your LiteRT workflow.

For certain pre-optimized models, we also provide an extra compiler options binary file, which you can pass to the compiler with the google_tensor_extra_options_path flag. To achieve optimal performance on Google Tensor, we recommend using the following optimization flags during compilation:

Compilation flags for Google Tensor

The following flags tune performance and resource usage during compilation. All of them are optional. You can pass them in your LiteRT Python workflow when compiling both PyTorch and TFLite models.

Flag | Requirement | Description | Default value
--- | --- | --- | ---
google_tensor_truncation_type | Optional | Sets the target data type for floating-point operations. Supported values: auto, bfloat16, half, no_truncation. | auto
google_tensor_sharding_intensity | Optional | Controls how aggressively the model is split for parallel processing. Supported values: minimal, moderate, extensive, maximum. | minimal
google_tensor_int64_to_int32 | Optional | Set to True to allow the compiler to convert 64-bit integers to 32-bit integers, which may be necessary for some models. | False
google_tensor_enable_large_model_support | Optional | Set to True if your model is larger than 2 GB. | False
google_tensor_enable_4bit_compilation | Optional | Set to True to compile models whose convolution operations contain 4-bit parameters. | False
google_tensor_extra_options_path | Optional | Path to a binary file containing extra compiler options. | "" (empty)
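Because the flags are plain keyword arguments, one way to catch a typo before a long compilation is to collect them in a dictionary and check them against the table. The following is a minimal sketch: `google_tensor_flags`, `DEFAULTS`, and `ALLOWED_VALUES` are hypothetical helpers (not part of LiteRT), with flag names, allowed values, and defaults copied from the table above. It returns only the flags that differ from their defaults, assuming that omitting a flag leaves the compiler's default in effect.

```python
# Hypothetical helper, not a LiteRT API: flag names, allowed values and
# defaults are copied from the flag table.
ALLOWED_VALUES = {
    "google_tensor_truncation_type": {"auto", "bfloat16", "half", "no_truncation"},
    "google_tensor_sharding_intensity": {"minimal", "moderate", "extensive", "maximum"},
}

DEFAULTS = {
    "google_tensor_truncation_type": "auto",
    "google_tensor_sharding_intensity": "minimal",
    "google_tensor_int64_to_int32": False,
    "google_tensor_enable_large_model_support": False,
    "google_tensor_enable_4bit_compilation": False,
    "google_tensor_extra_options_path": "",
}


def google_tensor_flags(**overrides):
    """Validate flag overrides and return only those that differ from the defaults."""
    flags = {}
    for name, value in overrides.items():
        if name not in DEFAULTS:
            raise ValueError(f"unknown Google Tensor flag: {name}")
        # Boolean flags must be bool and the options path must be str.
        if type(value) is not type(DEFAULTS[name]):
            raise TypeError(f"{name} expects a {type(DEFAULTS[name]).__name__}")
        # Enumerated flags must use one of the documented values.
        if name in ALLOWED_VALUES and value not in ALLOWED_VALUES[name]:
            raise ValueError(f"unsupported value for {name}: {value!r}")
        # Drop values equal to the default so the compiler default applies.
        if value != DEFAULTS[name]:
            flags[name] = value
    return flags


flags = google_tensor_flags(
    google_tensor_truncation_type="half",
    google_tensor_int64_to_int32=False,  # same as the default, so it is dropped
)
print(flags)  # {'google_tensor_truncation_type': 'half'}
```

You could then pass the result to the compiler with `aot_lib.aot_compile(tflite_model_path, target=[tensor_g5_target], **flags)`.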

You can pass these flags as keyword arguments in the LiteRT Python flow, as shown in the following snippets:

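  • When compiling a PyTorch model with ai_edge_torch:

    # tensor_g5_target, channel_last_selfie_segmentation and sample_input
    # are defined elsewhere in your workflow.
    compiled_models = ai_edge_torch.experimental_add_compilation_backend(
        tensor_g5_target,
        flag_name1=value1,  # Replace with flags from the table above.
        flag_name2=value2,  # Add further flags as needed.
    ).convert(channel_last_selfie_segmentation, sample_input)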
    
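  • When compiling a TFLite model:

    # tflite_model_path and tensor_g5_target are defined elsewhere in your workflow.
    compiled_models = aot_lib.aot_compile(
        tflite_model_path,
        target=[tensor_g5_target],
        flag_name1=value1,  # Replace with flags from the table above.
        flag_name2=value2,  # Add further flags as needed.
    )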
    

Usage example

The following example compiles a TFLite model with the google_tensor_truncation_type="half" flag:

  compiled_models = aot_lib.aot_compile(
      tflite_model_path,
      target=[tensor_g5_target],
      keep_going=False,
      google_tensor_truncation_type="half"
  )
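Flags can also be derived from the model itself. As a sketch, the helper below sets google_tensor_enable_large_model_support only when the .tflite file exceeds the 2 GB limit given in the flag table. `size_dependent_flags` is a hypothetical helper, and reading "2 GB" as 2 GiB (2**31 bytes) is an assumption.

```python
import os

# Assumption: "larger than 2GB" in the flag table is read as 2 GiB
# (2**31 bytes); adjust the threshold if your toolchain documents otherwise.
LARGE_MODEL_THRESHOLD_BYTES = 2**31


def size_dependent_flags(model_path, threshold_bytes=LARGE_MODEL_THRESHOLD_BYTES):
    """Return the large-model flag only when the model file exceeds the threshold."""
    if os.path.getsize(model_path) > threshold_bytes:
        return {"google_tensor_enable_large_model_support": True}
    return {}
```

For example, you could extend the call above to `aot_lib.aot_compile(tflite_model_path, target=[tensor_g5_target], keep_going=False, google_tensor_truncation_type="half", **size_dependent_flags(tflite_model_path))`.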

See the LiteRT AOT Colab for more information.

Compile language models for Google Tensor

To compile language models for Google Tensor, follow the instructions in NPU AOT compilation.

To export an LLM for the Google Tensor TPU, add the NPU compilation flags shown in the following example:

Example:

litert-torch export-hf \
  --model=google/gemma-3-270m-it \
  --output_dir=/tmp/gemma3-270m-google-tensor-g5 \
  --split_cache \
  --externalize_embedder \
  --prefill_lengths=128, \
  --cache_length=1280 \
  --quantization_recipe="weight_only_wi8_afp32" \
  --aot_backend=GOOGLE \
  --aot_soc_model=Tensor_G5 \
  --aot_compilation_config_dict='{"google_tensor_enable_large_model_support": True}'
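If you drive the export from a Python script rather than a shell, you can build the same command as an argument list. This sketch reproduces the flags from the example above; `export_hf_command` and `config_dict_arg` are hypothetical helpers. Formatting the config dictionary with Python literals (`True` rather than JSON `true`) only mirrors the example and is not a documented format.

```python
def config_dict_arg(config):
    """Format a flag dict like the example's --aot_compilation_config_dict value.

    Keys are double-quoted and values use Python literal syntax (True/False),
    mirroring the shell example; this format is an assumption, not a spec.
    """
    items = ", ".join(f'"{key}": {value!r}' for key, value in config.items())
    return "{" + items + "}"


def export_hf_command(model, output_dir, config):
    """Build the litert-torch export-hf argument list from the shell example."""
    # The shell strips the quotes around each value, so none are needed here.
    return [
        "litert-torch", "export-hf",
        f"--model={model}",
        f"--output_dir={output_dir}",
        "--split_cache",
        "--externalize_embedder",
        "--prefill_lengths=128,",
        "--cache_length=1280",
        "--quantization_recipe=weight_only_wi8_afp32",
        "--aot_backend=GOOGLE",
        "--aot_soc_model=Tensor_G5",
        f"--aot_compilation_config_dict={config_dict_arg(config)}",
    ]


cmd = export_hf_command(
    "google/gemma-3-270m-it",
    "/tmp/gemma3-270m-google-tensor-g5",
    {"google_tensor_enable_large_model_support": True},
)
```

With litert-torch installed, you could run it with `subprocess.run(cmd, check=True)`.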