Modèles JAX avec LiteRT

Cette page fournit un chemin d'accès aux utilisateurs qui souhaitent entraîner des modèles en JAX et les déployer au mobile à des fins d'inférence. Les méthodes décrites dans ce guide génèrent un tflite_model qui peut être utilisé directement avec l'exemple de code d'interpréteur LiteRT ou enregistré dans un fichier FlatBuffer tflite.

Pour obtenir un exemple de bout en bout, consultez le guide de démarrage rapide.

Conditions préalables

Nous vous recommandons d'essayer cette fonctionnalité avec la dernière version nocturne de TensorFlow d'un package.

pip install tf-nightly --upgrade

Nous allons utiliser le modèle Orbax Exportez la bibliothèque vers exporter des modèles JAX. Assurez-vous que la version de JAX est au moins la version 0.4.20.

pip install jax --upgrade
pip install orbax-export --upgrade

Convertir des modèles JAX au format LiteRT

Nous utilisons le SavedModel TensorFlow comme format intermédiaire entre JAX et LiteRT. Une fois que vous disposez d'un SavedModel les API LiteRT existantes peuvent alors être utilisées pour terminer le processus de conversion.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Vérifier le modèle TFLite converti

Une fois le modèle converti en TFLite, vous pouvez exécuter des API d'interpréteur TFLite pour vérifier les sorties du modèle.

# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])