طُرز JAX التي تستخدم LiteRT

توفّر هذه الصفحة مسارًا للمستخدمين الذين يريدون تدريب النماذج في JAX ونشرها. إلى الجوّال للاستنتاج. تنتج عن الطرق الواردة في هذا الدليل tflite_model والذي يمكن استخدامه مباشرةً مع مثال رمز المترجم الفوري LiteRT أو حفظه في ملف FlatBuffer tflite.

للحصول على مثال شامل، يُرجى الاطّلاع على البدء السريع.

المتطلبات الأساسية

ننصحك بتجربة هذه الميزة مع أحدث إصدار من TensorFlow ليلاً للغة Python. طرد.

pip install tf-nightly --upgrade

سنستخدم مكتبة Orbax Export لملفعِد نماذج JAX. تأكَّد من أنّ إصدار JAX يعمل بالإصدار 0.4.20 على الأقل أو الإصدارات الأحدث.

pip install jax --upgrade
pip install orbax-export --upgrade

تحويل نماذج JAX إلى LiteRT

نستخدم TensorFlow SavedModel كتنسيق متوسط بين JAX وliteRT. بمجرد أن يكون لديك نموذج محفوظ يمكن استخدام واجهات برمجة تطبيقات LiteRT الحالية لإكمال عملية التحويل.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

التحقّق من نموذج TFLite المُحوَّل

بعد تحويل النموذج إلى TFLite، يمكنك تشغيل واجهات برمجة التطبيقات الخاصة ببرنامج TFLite لتفسير النماذج بهدف التحقّق من نواتج النموذج.

# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])