Descripción general de la compatibilidad de LiteRT con JAX

LiteRT proporciona una ruta para convertir modelos de JAX para la inferencia en el dispositivo aprovechando el ecosistema de TensorFlow. El proceso implica una conversión de dos pasos: primero, de JAX a SavedModel de TensorFlow y, luego, de SavedModel al formato .tflite.

Proceso de conversión

  1. De JAX a TensorFlow SavedModel con jax2tf: El primer paso es convertir tu modelo de JAX al formato de TensorFlow SavedModel. Esto se hace con la herramienta jax2tf, que es una función experimental de JAX. jax2tf te permite convertir funciones de JAX en grafos de TensorFlow.

    Para obtener instrucciones y ejemplos detallados sobre cómo usar jax2tf, consulta la documentación oficial de jax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    Por lo general, este proceso implicará encapsular la función de predicción de tu modelo de JAX con jax2tf.convert y, luego, guardarla con tf.saved_model.save de TensorFlow.

  2. Modelo guardado de TensorFlow a TFLite: Una vez que tengas tu modelo en el formato de modelo guardado de TensorFlow, puedes convertirlo al formato de TFLite con el conversor estándar de TensorFlow Lite. Este proceso optimiza el modelo para la ejecución en el dispositivo, lo que reduce su tamaño y mejora el rendimiento.

    Las instrucciones detalladas para convertir un modelo guardado de TensorFlow a TFLite se pueden encontrar en la guía de conversión de modelos de TensorFlow.

    En esta guía, se abordan varias opciones y prácticas recomendadas para el proceso de conversión, incluida la cuantización y otras optimizaciones.

Si sigues estos dos pasos, podrás tomar los modelos que desarrollaste en JAX y, luego, implementarlos de manera eficiente en dispositivos perimetrales con el entorno de ejecución de LiteRT.