Présentation de la compatibilité LiteRT JAX

LiteRT permet de convertir des modèles JAX pour l'inférence sur l'appareil en tirant parti de l'écosystème TensorFlow. Le processus implique une conversion en deux étapes : d'abord de JAX vers TensorFlow SavedModel, puis de SavedModel vers le format .tflite.

Processus de conversion

  1. JAX vers TensorFlow SavedModel à l'aide de jax2tf : la première étape consiste à convertir votre modèle JAX au format TensorFlow SavedModel. Pour ce faire, nous utilisons l'outil jax2tf, qui est une fonctionnalité expérimentale de JAX. jax2tf vous permet de convertir des fonctions JAX en graphiques TensorFlow.

    Pour obtenir des instructions détaillées et des exemples sur l'utilisation de jax2tf, veuillez consulter la documentation officielle de jax2tf : https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md.

    Ce processus implique généralement d'encapsuler la fonction de prédiction de votre modèle JAX avec jax2tf.convert, puis de l'enregistrer à l'aide de tf.saved_model.save de TensorFlow.

  2. TensorFlow SavedModel vers TFLite : une fois que votre modèle est au format TensorFlow SavedModel, vous pouvez le convertir au format TFLite à l'aide du convertisseur TensorFlow Lite standard. Ce processus optimise le modèle pour l'exécution sur l'appareil, en réduisant sa taille et en améliorant ses performances.

    Vous trouverez des instructions détaillées pour convertir un modèle SavedModel TensorFlow en modèle TFLite dans le guide de conversion de modèle TensorFlow.

    Ce guide aborde différentes options et bonnes pratiques pour le processus de conversion, y compris la quantification et d'autres optimisations.

En suivant ces deux étapes, vous pouvez prendre vos modèles développés dans JAX et les déployer efficacement sur des appareils de périphérie à l'aide du runtime LiteRT.