דגמים של JAX עם TensorFlow Lite

דף זה מספק נתיב למשתמשים שרוצים לאמן מודלים ב-JAX ולפרוס מודלים ניידים לצורך הסקת מסקנות (colab לדוגמה.

השיטות במדריך הזה יוצרות tflite_model שאפשר להשתמש בו ישירות עם הדוגמה של קוד תרגום TFLite או לשמור אותו בקובץ TFLite FlatBuffer.

ידע מוקדם שנדרש לקורס

מומלץ לנסות את התכונה הזו עם חבילת Python בלילה החדשה ביותר של TensorFlow.

pip install tf-nightly --upgrade

אנחנו נשתמש בספרייה Orbax Export כדי לייצא מודלים של JAX. חשוב לוודא שהגרסה של JAX היא 0.4.20 ואילך.

pip install jax --upgrade
pip install orbax-export --upgrade

המרת מודלים של JAX ל-TensorFlow Lite

אנחנו משתמשים ב-TensorFlow SavedModel כפורמט הביניים בין JAX ל-TensorFlow Lite. אחרי שיוצרים מודל SaveModel, אפשר להשתמש בממשקי ה-API של TensorFlow Lite כדי להשלים את תהליך ההמרה.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

בדיקת מודל TFLite שעבר המרה

אחרי ההמרה של המודל ל-TFLite, תוכלו להריץ את ממשקי ה-API של תרגום TFLite כדי לבדוק את הפלט של המודל.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])