LiteRT provides a path to convert JAX models for on-device inference by leveraging the TensorFlow ecosystem. The process involves a two-step conversion: first from JAX to TensorFlow SavedModel, and then from SavedModel to the .tflite format.
Conversion Process
1. JAX to TensorFlow SavedModel using jax2tf: The first step is to convert your JAX model into the TensorFlow SavedModel format. This is done with the jax2tf tool, an experimental JAX feature that converts JAX functions into TensorFlow graphs. For detailed instructions and examples on how to use jax2tf, refer to the official jax2tf documentation: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md. The process typically involves wrapping your JAX model's prediction function with jax2tf.convert and then saving it with TensorFlow's tf.saved_model.save; a minimal sketch is shown after these steps.

2. TensorFlow SavedModel to TFLite: Once you have your model in the TensorFlow SavedModel format, you can convert it to the TFLite format using the standard TensorFlow Lite converter; see the second sketch below. This step optimizes the model for on-device execution, reducing its size and improving performance.
The detailed instructions for converting a TensorFlow SavedModel to TFLite can be found in the TensorFlow Model conversion guide.
This guide covers various options and best practices for the conversion process, including quantization and other optimizations.
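The following is a minimal sketch of the first step. The predict function, its parameters, and the /tmp/jax_saved_model path are illustrative placeholders rather than part of LiteRT or jax2tf, and depending on your jax2tf version you may need additional conversion options (for example around native serialization) for TFLite compatibility.

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Hypothetical JAX prediction function standing in for your model.
def predict(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])

params = {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)}

# Wrap the prediction function as a TensorFlow function; closing over params
# keeps the converted signature to a single input tensor.
tf_predict = tf.function(
    jax2tf.convert(lambda x: predict(params, x)),
    input_signature=[tf.TensorSpec(shape=(1, 4), dtype=tf.float32, name="x")],
    autograph=False,
)

# Export as a TensorFlow SavedModel with a serving signature.
module = tf.Module()
module.predict = tf_predict
tf.saved_model.save(
    module,
    "/tmp/jax_saved_model",
    signatures={"serving_default": module.predict},
)
```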
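And a corresponding sketch of the second step, converting that SavedModel to a .tflite flatbuffer with the standard TensorFlow Lite converter. The SELECT_TF_OPS fallback and the default optimizations are optional choices shown here for illustration; whether you need them depends on which ops your converted model uses and how aggressively you want to optimize it.

```python
import tensorflow as tf

# Load the SavedModel produced in the previous step.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/jax_saved_model")

# Optional: allow select TF ops as a fallback for any converted ops that have
# no built-in TFLite kernel, and apply the default optimizations (dynamic
# range quantization of weights).
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()

# Write the flatbuffer to disk for deployment with the LiteRT runtime.
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```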
By following these two steps, you can take your models developed in JAX and deploy them efficiently on edge devices using the LiteRT runtime.