This page provides a path for users who want to train models in JAX and deploy
them to mobile for inference. The methods in this guide produce a `tflite_model`
which can be used directly with the LiteRT interpreter code example below or saved
to a `tflite` FlatBuffer file.
We use the TensorFlow SavedModel
as the intermediate format between JAX and LiteRT. Once you have a SavedModel,
the existing LiteRT APIs can be used to complete the conversion process.
    # This code snippet converts a JAX model to TFLite through TF SavedModel.
    from orbax.export import ExportManager
    from orbax.export import JaxModule
    from orbax.export import ServingConfig
    import tensorflow as tf
    import jax.numpy as jnp

    def model_fn(_, x):
      return jnp.sin(jnp.cos(x))

    jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

    # Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
    # processing.
    tf.saved_model.save(
        jax_module,
        '/some/directory',
        signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        ),
        options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
    )
    converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
    tflite_model = converter.convert()

    # Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
    serving_config = ServingConfig(
        'Serving_default',
        # Corresponds to the input signature of `tf_preprocessor`
        input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
        tf_preprocessor=lambda x: x,
        tf_postprocessor=lambda out: {'output': out}
    )
    export_mgr = ExportManager(jax_module, [serving_config])
    export_mgr.save('/some/directory')
    converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
    tflite_model = converter.convert()

    # Option 3: Convert from TF concrete function directly
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [
            jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
                tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
            )
        ]
    )
    tflite_model = converter.convert()
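The `tflite_model` returned by `converter.convert()` is a bytes object, so you can
write it straight to disk to get the `tflite` FlatBuffer file mentioned above. A
minimal sketch; the `model.tflite` path is only an example:

    # Save the converted model to a FlatBuffer file (the filename is illustrative).
    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)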
Check the converted TFLite model
After the model is converted to TFLite, you can use the TFLite interpreter APIs to
check the model outputs.
    # Run the model with LiteRT
    import numpy as np

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Example input matching the (None,) float32 input signature used above.
    input_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    interpreter.set_tensor(input_details[0]["index"], input_data)
    interpreter.invoke()
    result = interpreter.get_tensor(output_details[0]["index"])
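To sanity-check the conversion, you can compare the LiteRT result against the
original JAX function evaluated on the same input. A minimal sketch, assuming the
`model_fn` and `input_data` defined above; the tolerances are illustrative:

    # The converted model should closely match jnp.sin(jnp.cos(x)) from model_fn.
    expected = model_fn({}, input_data)
    np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)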
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-09-16 UTC."],[],[],null,["# JAX models with LiteRT\n\nThis page provides a path for users who want to train models in JAX and deploy\nto mobile for inference. The methods in this guide produce a `tflite_model`\nwhich can be used directly with the LiteRT interpreter code example or saved to\na `tflite` FlatBuffer file.\n\nFor an end-to-end example, see the [quickstart](./jax_to_tflite).\n\nPrerequisite\n------------\n\nIt's recommended to try this feature with the newest TensorFlow nightly Python\npackage. \n\n pip install tf-nightly --upgrade\n\nWe will use the [Orbax\nExport](https://orbax.readthedocs.io/en/latest/orbax_export_101.html) library to\nexport JAX models. Make sure your JAX version is at least 0.4.20 or above. \n\n pip install jax --upgrade\n pip install orbax-export --upgrade\n\nConvert JAX models to LiteRT\n----------------------------\n\nWe use the TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model)\nas the intermediate format between JAX and LiteRT. Once you have a SavedModel\nthen existing LiteRT APIs can be used to complete the conversion process. \n\n # This code snippet converts a JAX model to TFLite through TF SavedModel.\n from orbax.export import ExportManager\n from orbax.export import JaxModule\n from orbax.export import ServingConfig\n import tensorflow as tf\n import jax.numpy as jnp\n\n def model_fn(_, x):\n return jnp.sin(jnp.cos(x))\n\n jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')\n\n # Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post\n # processing.\n tf.saved_model.save(\n jax_module,\n '/some/directory',\n signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n tf.TensorSpec(shape=(None,), dtype=tf.float32, name=\"input\")\n ),\n options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),\n )\n converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')\n tflite_model = converter.convert()\n\n # Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).\n serving_config = ServingConfig(\n 'Serving_default',\n # Corresponds to the input signature of `tf_preprocessor`\n input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],\n tf_preprocessor=lambda x: x,\n tf_postprocessor=lambda out: {'output': out}\n )\n export_mgr = ExportManager(jax_module, [serving_config])\n export_mgr.save('/some/directory')\n converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')\n tflite_model = converter.convert()\n\n # Option 3: Convert from TF concrete function directly\n converter = tf.lite.TFLiteConverter.from_concrete_functions(\n [\n jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n tf.TensorSpec(shape=(None,), dtype=tf.float32, name=\"input\")\n )\n ]\n )\n tflite_model = converter.convert()\n\nCheck the converted TFLite model\n--------------------------------\n\nAfter the model is converted to TFLite, you can run TFLite interpreter APIs to\ncheck model outputs. 
\n\n # Run the model with LiteRT\n interpreter = tf.lite.Interpreter(model_content=tflite_model)\n interpreter.allocate_tensors() input_details = interpreter.get_input_details()\n output_details = interpreter.get_output_details()\n interpreter.set_tensor(input_details[0][\"index\"], input_data)\n interpreter.invoke()\n result = interpreter.get_tensor(output_details[0][\"index\"])"]]