```python
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp


# JaxModule calls the function as fn(params, inputs); this toy model has no
# parameters, so the first argument is ignored.
def model_fn(_, x):
  return jnp.sin(jnp.cos(x))


# 'b, ...' makes the leading (batch) dimension symbolic, so the exported
# model accepts inputs of any batch size.
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Save the model directly with `tf.saved_model.save` if no
# pre/post-processing is needed.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define TF pre/post-processing functions (e.g. tokenization and
# detokenization) and export them together with the model.
serving_config = ServingConfig(
    'serving_default',
    # Corresponds to the input signature of `tf_preprocessor`.
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out},
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert directly from a TF concrete function.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ],
    # Keep a reference to the module so its variables stay alive during conversion.
    jax_module,
)
tflite_model = converter.convert()
```
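Each option produces `tflite_model`, a `bytes` object that holds a FlatBuffer and can also be written to a `.tflite` file. The following is a minimal standard-library sketch of that step. The helper names `has_tflite_identifier` and `save_tflite_model` are hypothetical, as is the file name `model.tflite`. The sanity check relies on the file identifier `TFL3` declared in the TFLite FlatBuffer schema. FlatBuffers store that identifier at byte offset 4.

```python
from pathlib import Path

# File identifier declared in the TFLite FlatBuffer schema.
TFLITE_FILE_IDENTIFIER = b"TFL3"


def has_tflite_identifier(model_bytes: bytes) -> bool:
    """Return True if the buffer carries the TFLite FlatBuffer identifier.

    FlatBuffers place the 4-byte file identifier right after the 4-byte
    root-table offset, i.e. at bytes 4..8 of the buffer.
    """
    return model_bytes[4:8] == TFLITE_FILE_IDENTIFIER


def save_tflite_model(model_bytes: bytes, path: str) -> Path:
    """Write a converted model to disk after a basic check (hypothetical helper)."""
    if not has_tflite_identifier(model_bytes):
        raise ValueError("buffer does not look like a TFLite FlatBuffer")
    out = Path(path)
    out.write_bytes(model_bytes)
    return out
```

For example, `save_tflite_model(tflite_model, "model.tflite")` writes the converted model next to your script. The identifier check is only a cheap guard against writing the wrong buffer. It does not validate the model itself.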
Check the converted TFLite model
--------------------------------

After the model is converted to TFLite, you can use the TFLite interpreter API to run it and check the model's outputs.
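```python
# Run the model with LiteRT
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Example input (arbitrary values). The model was exported with a dynamic
# batch dimension, so resize the input tensor before allocating tensors.
input_data = np.array([0.0, 1.0, 2.0], dtype=np.float32)
interpreter.resize_tensor_input(input_details[0]["index"], input_data.shape)
interpreter.allocate_tensors()

interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])
```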
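A simple way to check the outputs is to compare `result` with the original function evaluated on the same input. The sketch below assumes the example model on this page, `sin(cos(x))`, and uses Python's `math` module as the reference. That choice lets it accept a NumPy array or a plain list. The helper names `reference_model`, `max_abs_error`, and `outputs_match` are hypothetical. The default tolerance of `1e-5` is an assumption chosen for float32 outputs.

```python
import math
from typing import Iterable


def reference_model(x: float) -> float:
    """Pure-Python reference for the example model_fn: sin(cos(x))."""
    return math.sin(math.cos(x))


def max_abs_error(result: Iterable[float], inputs: Iterable[float]) -> float:
    """Largest absolute difference between interpreter outputs and the reference."""
    errors = [abs(float(y) - reference_model(float(x))) for y, x in zip(result, inputs)]
    return max(errors, default=0.0)


def outputs_match(result, inputs, atol: float = 1e-5) -> bool:
    """True if shapes agree and every output is within `atol` of the reference."""
    result, inputs = list(result), list(inputs)
    return len(result) == len(inputs) and max_abs_error(result, inputs) <= atol
```

With the interpreter code above, `assert outputs_match(result, input_data)` should hold if the conversion preserved the model's numerics. For a real model, you would replace `reference_model` with a call to the original JAX function.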
Last updated 2025-07-24 UTC