Modèles JAX avec TensorFlow Lite

Cette page propose un parcours aux utilisateurs qui souhaitent entraîner des modèles dans JAX et les déployer sur mobile à des fins d'inférence (exemple Colab).

Les méthodes de ce guide génèrent un tflite_model qui peut être utilisé directement avec l'exemple de code de l'interpréteur TFLite ou enregistré dans un fichier FlatBuffer TFLite.

Prérequis

Nous vous recommandons d'essayer cette fonctionnalité avec le tout dernier package Python nocturne TensorFlow.

pip install tf-nightly --upgrade

Nous utiliserons la bibliothèque Orbax Export pour exporter des modèles JAX. Assurez-vous de disposer de la version 0.4.20 ou ultérieure de JAX.

pip install jax --upgrade
pip install orbax-export --upgrade

Convertir des modèles JAX en modèles TensorFlow Lite

Nous utilisons le SavedModel TensorFlow comme format intermédiaire entre JAX et TensorFlow Lite. Une fois que vous disposez d'un SavedModel, les API TensorFlow Lite existantes peuvent être utilisées pour terminer le processus de conversion.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Vérifier le modèle TFLite converti

Une fois le modèle converti au format TFLite, vous pouvez exécuter les API de l'interpréteur TFLite pour vérifier les sorties du modèle.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])