Trang này cung cấp đường dẫn cho những người dùng muốn huấn luyện mô hình bằng JAX và triển khai
đến thiết bị di động để suy luận. Các phương thức trong hướng dẫn này tạo ra một tflite_model
có thể dùng trực tiếp với ví dụ về mã thông dịch LiteRT hoặc được lưu vào
tệp FlatBuffer tflite.
Chúng tôi sử dụng SavedModel TensorFlow
làm định dạng trung gian giữa JAX và LiteRT. Sau khi bạn có SavedModel
thì bạn có thể sử dụng các API LiteRT hiện có để hoàn tất quá trình chuyển đổi.
# This code snippet converts a JAX model to TFLite through TF SavedModel.fromorbax.exportimportExportManagerfromorbax.exportimportJaxModulefromorbax.exportimportServingConfigimporttensorflowastfimportjax.numpyasjnpdefmodel_fn(_,x):returnjnp.sin(jnp.cos(x))jax_module=JaxModule({},model_fn,input_polymorphic_shape='b, ...')# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post# processing.tf.saved_model.save(jax_module,'/some/directory',signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(tf.TensorSpec(shape=(None,),dtype=tf.float32,name="input")),options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),)converter=tf.lite.TFLiteConverter.from_saved_model('/some/directory')tflite_model=converter.convert()# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).serving_config=ServingConfig('Serving_default',# Corresponds to the input signature of `tf_preprocessor`input_signature=[tf.TensorSpec(shape=(None,),dtype=tf.float32,name='input')],tf_preprocessor=lambdax:x,tf_postprocessor=lambdaout:{'output':out})export_mgr=ExportManager(jax_module,[serving_config])export_mgr.save('/some/directory')converter=tf.lite.TFLiteConverter.from_saved_model('/some/directory')tflite_model=converter.convert()# Option 3: Convert from TF concrete function directlyconverter=tf.lite.TFLiteConverter.from_concrete_functions([jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(tf.TensorSpec(shape=(None,),dtype=tf.float32,name="input"))])tflite_model=converter.convert()
Kiểm tra mô hình TFLite đã chuyển đổi
Sau khi mô hình được chuyển đổi sang TFLite, bạn có thể chạy các API phiên dịch TFLite để
kiểm tra đầu ra của mô hình.
# Run the model with LiteRTinterpreter=tf.lite.Interpreter(model_content=tflite_model)interpreter.allocate_tensors()input_details=interpreter.get_input_details()output_details=interpreter.get_output_details()interpreter.set_tensor(input_details[0]["index"],input_data)interpreter.invoke()result=interpreter.get_tensor(output_details[0]["index"])
[[["Dễ hiểu","easyToUnderstand","thumb-up"],["Giúp tôi giải quyết được vấn đề","solvedMyProblem","thumb-up"],["Khác","otherUp","thumb-up"]],[["Thiếu thông tin tôi cần","missingTheInformationINeed","thumb-down"],["Quá phức tạp/quá nhiều bước","tooComplicatedTooManySteps","thumb-down"],["Đã lỗi thời","outOfDate","thumb-down"],["Vấn đề về bản dịch","translationIssue","thumb-down"],["Vấn đề về mẫu/mã","samplesCodeIssue","thumb-down"],["Khác","otherDown","thumb-down"]],["Cập nhật lần gần đây nhất: 2025-07-24 UTC."],[],[],null,["# JAX models with LiteRT\n\nThis page provides a path for users who want to train models in JAX and deploy\nto mobile for inference. The methods in this guide produce a `tflite_model`\nwhich can be used directly with the LiteRT interpreter code example or saved to\na `tflite` FlatBuffer file.\n\nFor an end-to-end example, see the [quickstart](./jax_to_tflite).\n\nPrerequisite\n------------\n\nIt's recommended to try this feature with the newest TensorFlow nightly Python\npackage. \n\n pip install tf-nightly --upgrade\n\nWe will use the [Orbax\nExport](https://orbax.readthedocs.io/en/latest/orbax_export_101.html) library to\nexport JAX models. Make sure your JAX version is at least 0.4.20 or above. \n\n pip install jax --upgrade\n pip install orbax-export --upgrade\n\nConvert JAX models to LiteRT\n----------------------------\n\nWe use the TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model)\nas the intermediate format between JAX and LiteRT. Once you have a SavedModel\nthen existing LiteRT APIs can be used to complete the conversion process. \n\n # This code snippet converts a JAX model to TFLite through TF SavedModel.\n from orbax.export import ExportManager\n from orbax.export import JaxModule\n from orbax.export import ServingConfig\n import tensorflow as tf\n import jax.numpy as jnp\n\n def model_fn(_, x):\n return jnp.sin(jnp.cos(x))\n\n jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')\n\n # Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post\n # processing.\n tf.saved_model.save(\n jax_module,\n '/some/directory',\n signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n tf.TensorSpec(shape=(None,), dtype=tf.float32, name=\"input\")\n ),\n options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),\n )\n converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')\n tflite_model = converter.convert()\n\n # Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).\n serving_config = ServingConfig(\n 'Serving_default',\n # Corresponds to the input signature of `tf_preprocessor`\n input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],\n tf_preprocessor=lambda x: x,\n tf_postprocessor=lambda out: {'output': out}\n )\n export_mgr = ExportManager(jax_module, [serving_config])\n export_mgr.save('/some/directory')\n converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')\n tflite_model = converter.convert()\n\n # Option 3: Convert from TF concrete function directly\n converter = tf.lite.TFLiteConverter.from_concrete_functions(\n [\n jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n tf.TensorSpec(shape=(None,), dtype=tf.float32, name=\"input\")\n )\n ]\n )\n tflite_model = converter.convert()\n\nCheck the converted TFLite model\n--------------------------------\n\nAfter the model is converted to TFLite, you can run TFLite interpreter APIs to\ncheck model outputs. \n\n # Run the model with LiteRT\n interpreter = tf.lite.Interpreter(model_content=tflite_model)\n interpreter.allocate_tensors() input_details = interpreter.get_input_details()\n output_details = interpreter.get_output_details()\n interpreter.set_tensor(input_details[0][\"index\"], input_data)\n interpreter.invoke()\n result = interpreter.get_tensor(output_details[0][\"index\"])"]]