LiteRT JAX 支援總覽

LiteRT 運用 TensorFlow 生態系統,提供轉換 JAX 模型以進行裝置端推論的路徑。這個程序包含兩個轉換步驟:首先從 JAX 轉換為 TensorFlow SavedModel,然後從 SavedModel 轉換為 .tflite 格式。

轉換程序

  1. 使用 jax2tf 將 JAX 轉換為 TensorFlow SavedModel:第一步是將 JAX 模型轉換為 TensorFlow SavedModel 格式。這項作業是使用 jax2tf 工具完成,這是實驗性的 JAX 實驗功能。jax2tf 可將 JAX 函式轉換為 TensorFlow 圖表。

    如需如何使用 jax2tf 的詳細操作說明和範例,請參閱官方 jax2tf 說明文件: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    這個程序通常會將 JAX 模型的預測函式包裝在 jax2tf.convert 中,然後使用 TensorFlow 的 tf.saved_model.save 儲存。

  2. TensorFlow SavedModel 轉換為 TFLite:模型採用 TensorFlow SavedModel 格式後,即可使用標準的 TensorFlow Lite 轉換工具,將模型轉換為 TFLite 格式。這個程序會針對裝置端執行作業最佳化模型,縮減模型大小並提升效能。

    如需將 TensorFlow SavedModel 轉換為 TFLite 的詳細操作說明,請參閱 TensorFlow 模型轉換指南

    本指南涵蓋轉換程序的各種選項和最佳做法,包括量化和其他最佳化做法。

完成這兩個步驟後,您就能使用 LiteRT 執行階段,在邊緣裝置上有效部署以 JAX 開發的模型。