LiteRT JAX 支持概览

LiteRT 提供了一种途径,可利用 TensorFlow 生态系统将 JAX 模型转换为设备端推断模型。此流程涉及两步转换:首先从 JAX 转换为 TensorFlow SavedModel,然后从 SavedModel 转换为 .tflite 格式。

转换过程

  1. 使用 jax2tf 将 JAX 转换为 TensorFlow SavedModel:第一步是将 JAX 模型转换为 TensorFlow SavedModel 格式。这是使用 jax2tf 工具完成的,该工具是一项实验性 JAX 实验性功能。 借助 jax2tf,您可以将 JAX 函数转换为 TensorFlow 图。

    如需查看有关如何使用 jax2tf 的详细说明和示例,请参阅官方 jax2tf 文档:https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    此过程通常涉及使用 jax2tf.convert 封装 JAX 模型的预测函数,然后使用 TensorFlow 的 tf.saved_model.save 保存该函数。

  2. 将 TensorFlow SavedModel 转换为 TFLite:当您的模型采用 TensorFlow SavedModel 格式时,您可以使用标准 TensorFlow Lite 转换器将其转换为 TFLite 格式。此流程可优化模型以在设备上执行,从而减小模型大小并提高性能。

    如需详细了解如何将 TensorFlow SavedModel 转换为 TFLite,请参阅 TensorFlow 模型转换指南

    本指南介绍了转换过程的各种选项和最佳实践,包括量化和其他优化。

按照这两个步骤操作,您就可以将使用 JAX 开发的模型高效地部署到边缘设备上,并使用 LiteRT 运行时。