طُرز JAX مع TensorFlow Lite

توفّر هذه الصفحة مسارًا للمستخدمين الذين يريدون تدريب النماذج في JAX ونشرها على الأجهزة الجوّالة من أجل الاستنتاج (مثال على colab.

تُنتج الطرق الواردة في هذا الدليل tflite_model يمكن استخدامه مباشرةً باستخدام مثال الرمز البرمجي الخاص بمترجم TFLite أو حفظه في ملف TFLite FlatBuffer.

المتطلبات الأساسية

ويُوصى بتجربة هذه الميزة مع أحدث حزمة من برامج البايثون الليلية من TensorFlow.

pip install tf-nightly --upgrade

سنستخدم مكتبة Orbax Export لتصدير نماذج JAX. تأكَّد من أنّ إصدار JAX لا يقلّ عن 0.4.20 أو إصدار أحدث.

pip install jax --upgrade
pip install orbax-export --upgrade

تحويل نماذج JAX إلى TensorFlow Lite

نستخدم نموذج TensorFlow SavedModel كتنسيق متوسط بين JAX وTensorFlow Lite. بعد أن يكون لديك SaveModel، يمكن استخدام واجهات برمجة تطبيقات TensorFlow Lite الحالية لإكمال عملية التحويل.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

التحقّق من نموذج TFLite المحوَّل

بعد تحويل النموذج إلى نموذج TFLite، يمكنك تشغيل واجهات برمجة تطبيقات TFLite للمترجم للتحقّق من مخرجات النموذج.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])