توفّر هذه الصفحة مسارًا للمستخدمين الذين يريدون تدريب النماذج في JAX ونشرها.
إلى الجوّال للاستنتاج. تنتج عن الطرق الواردة في هذا الدليل tflite_model
والذي يمكن استخدامه مباشرةً مع مثال رمز المترجم الفوري LiteRT أو حفظه في
ملف FlatBuffer tflite.
للحصول على مثال شامل، يُرجى الاطّلاع على البدء السريع.
المتطلبات الأساسية
ننصحك بتجربة هذه الميزة مع أحدث إصدار من TensorFlow ليلاً للغة Python.
طرد.
pip install tf-nightly --upgrade
سنستخدم مكتبة Orbax
Export لملفعِد نماذج JAX. تأكَّد من أنّ إصدار JAX يعمل بالإصدار 0.4.20 على الأقل أو الإصدارات الأحدث.
نستخدم TensorFlow SavedModel
كتنسيق متوسط بين JAX وliteRT. بمجرد أن يكون لديك نموذج محفوظ
يمكن استخدام واجهات برمجة تطبيقات LiteRT الحالية لإكمال عملية التحويل.
# This code snippet converts a JAX model to TFLite through TF SavedModel.fromorbax.exportimportExportManagerfromorbax.exportimportJaxModulefromorbax.exportimportServingConfigimporttensorflowastfimportjax.numpyasjnpdefmodel_fn(_,x):returnjnp.sin(jnp.cos(x))jax_module=JaxModule({},model_fn,input_polymorphic_shape='b, ...')# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post# processing.tf.saved_model.save(jax_module,'/some/directory',signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(tf.TensorSpec(shape=(None,),dtype=tf.float32,name="input")),options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),)converter=tf.lite.TFLiteConverter.from_saved_model('/some/directory')tflite_model=converter.convert()# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).serving_config=ServingConfig('Serving_default',# Corresponds to the input signature of `tf_preprocessor`input_signature=[tf.TensorSpec(shape=(None,),dtype=tf.float32,name='input')],tf_preprocessor=lambdax:x,tf_postprocessor=lambdaout:{'output':out})export_mgr=ExportManager(jax_module,[serving_config])export_mgr.save('/some/directory')converter=tf.lite.TFLiteConverter.from_saved_model('/some/directory')tflite_model=converter.convert()# Option 3: Convert from TF concrete function directlyconverter=tf.lite.TFLiteConverter.from_concrete_functions([jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(tf.TensorSpec(shape=(None,),dtype=tf.float32,name="input"))])tflite_model=converter.convert()
التحقّق من نموذج TFLite المُحوَّل
بعد تحويل النموذج إلى TFLite، يمكنك تشغيل واجهات برمجة التطبيقات الخاصة ببرنامج TFLite لتفسير النماذج بهدف
التحقّق من نواتج النموذج.
# Run the model with LiteRTinterpreter=tf.lite.Interpreter(model_content=tflite_model)interpreter.allocate_tensors()input_details=interpreter.get_input_details()output_details=interpreter.get_output_details()interpreter.set_tensor(input_details[0]["index"],input_data)interpreter.invoke()result=interpreter.get_tensor(output_details[0]["index"])
تاريخ التعديل الأخير: 2025-07-24 (حسب التوقيت العالمي المتفَّق عليه)
[[["يسهُل فهم المحتوى.","easyToUnderstand","thumb-up"],["ساعَدني المحتوى في حلّ مشكلتي.","solvedMyProblem","thumb-up"],["غير ذلك","otherUp","thumb-up"]],[["لا يحتوي على المعلومات التي أحتاج إليها.","missingTheInformationINeed","thumb-down"],["الخطوات معقدة للغاية / كثيرة جدًا.","tooComplicatedTooManySteps","thumb-down"],["المحتوى قديم.","outOfDate","thumb-down"],["ثمة مشكلة في الترجمة.","translationIssue","thumb-down"],["مشكلة في العيّنات / التعليمات البرمجية","samplesCodeIssue","thumb-down"],["غير ذلك","otherDown","thumb-down"]],["تاريخ التعديل الأخير: 2025-07-24 (حسب التوقيت العالمي المتفَّق عليه)"],[],[],null,["# JAX models with LiteRT\n\nThis page provides a path for users who want to train models in JAX and deploy\nto mobile for inference. The methods in this guide produce a `tflite_model`\nwhich can be used directly with the LiteRT interpreter code example or saved to\na `tflite` FlatBuffer file.\n\nFor an end-to-end example, see the [quickstart](./jax_to_tflite).\n\nPrerequisite\n------------\n\nIt's recommended to try this feature with the newest TensorFlow nightly Python\npackage. \n\n pip install tf-nightly --upgrade\n\nWe will use the [Orbax\nExport](https://orbax.readthedocs.io/en/latest/orbax_export_101.html) library to\nexport JAX models. Make sure your JAX version is at least 0.4.20 or above. \n\n pip install jax --upgrade\n pip install orbax-export --upgrade\n\nConvert JAX models to LiteRT\n----------------------------\n\nWe use the TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model)\nas the intermediate format between JAX and LiteRT. Once you have a SavedModel\nthen existing LiteRT APIs can be used to complete the conversion process. \n\n # This code snippet converts a JAX model to TFLite through TF SavedModel.\n from orbax.export import ExportManager\n from orbax.export import JaxModule\n from orbax.export import ServingConfig\n import tensorflow as tf\n import jax.numpy as jnp\n\n def model_fn(_, x):\n return jnp.sin(jnp.cos(x))\n\n jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')\n\n # Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post\n # processing.\n tf.saved_model.save(\n jax_module,\n '/some/directory',\n signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n tf.TensorSpec(shape=(None,), dtype=tf.float32, name=\"input\")\n ),\n options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),\n )\n converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')\n tflite_model = converter.convert()\n\n # Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).\n serving_config = ServingConfig(\n 'Serving_default',\n # Corresponds to the input signature of `tf_preprocessor`\n input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],\n tf_preprocessor=lambda x: x,\n tf_postprocessor=lambda out: {'output': out}\n )\n export_mgr = ExportManager(jax_module, [serving_config])\n export_mgr.save('/some/directory')\n converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')\n tflite_model = converter.convert()\n\n # Option 3: Convert from TF concrete function directly\n converter = tf.lite.TFLiteConverter.from_concrete_functions(\n [\n jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n tf.TensorSpec(shape=(None,), dtype=tf.float32, name=\"input\")\n )\n ]\n )\n tflite_model = converter.convert()\n\nCheck the converted TFLite model\n--------------------------------\n\nAfter the model is converted to TFLite, you can run TFLite interpreter APIs to\ncheck model outputs. \n\n # Run the model with LiteRT\n interpreter = tf.lite.Interpreter(model_content=tflite_model)\n interpreter.allocate_tensors() input_details = interpreter.get_input_details()\n output_details = interpreter.get_output_details()\n interpreter.set_tensor(input_details[0][\"index\"], input_data)\n interpreter.invoke()\n result = interpreter.get_tensor(output_details[0][\"index\"])"]]