Modelos JAX con TensorFlow Lite

En esta página, se proporciona una ruta para los usuarios que desean entrenar modelos en JAX y, luego, implementarlos en dispositivos móviles para realizar inferencias (ejemplo de Colab.

Los métodos de esta guía producen un tflite_model que se puede usar directamente con el ejemplo de código del intérprete de TFLite o guardar en un archivo FlatBuffer de TFLite.

Requisitos

Se recomienda probar esta función con el paquete nocturno de Python de TensorFlow más reciente.

pip install tf-nightly --upgrade

Usaremos la biblioteca de Orbax Export para exportar modelos de JAX. Asegúrate de que tu versión de JAX sea, al menos, 0.4.20 o posterior.

pip install jax --upgrade
pip install orbax-export --upgrade

Convierte modelos de JAX a TensorFlow Lite

Usamos el SavedModel de TensorFlow como el formato intermedio entre JAX y TensorFlow Lite. Una vez que tengas un modelo guardado, podrás usar las APIs de TensorFlow Lite existentes para completar el proceso de conversión.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Verifica el modelo de TFLite convertido

Después de convertir el modelo en TFLite, puedes ejecutar las APIs de intérprete de TFLite para verificar los resultados del modelo.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])