En esta página, se proporciona una ruta para los usuarios que desean entrenar modelos en JAX y, luego, implementarlos en dispositivos móviles para la inferencia. Los métodos de esta guía producen un tflite_model
que puede usarse directamente con el ejemplo de código de intérprete de LiteRT o guardarse en
un archivo FlatBuffer tflite
.
Para ver un ejemplo de extremo a extremo, consulta la guía de inicio rápido.
Requisitos
Se recomienda probar esta función con la versión más reciente de TensorFlow .
pip install tf-nightly --upgrade
Usaremos la biblioteca Orbax Export para exportar modelos de JAX. Asegúrate de que tu versión de JAX sea, al menos, 0.4.20 o una posterior.
pip install jax --upgrade
pip install orbax-export --upgrade
Convierte modelos de JAX a LiteRT
Usamos el SavedModel de TensorFlow. como formato intermedio entre JAX y LiteRT. Una vez que tengas un modelo guardado se pueden usar las APIs de LiteRT existentes para completar el proceso de conversión.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
Verifica el modelo de TFLite convertido
Después de convertir el modelo a TFLite, puedes ejecutar las APIs de intérprete de TFLite para verificar los resultados del modelo.
# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])