Panoramica del supporto di LiteRT JAX

LiteRT fornisce un percorso per convertire i modelli JAX per l'inferenza on-device sfruttando l'ecosistema TensorFlow. Il processo prevede una conversione in due passaggi: prima da JAX a TensorFlow SavedModel e poi da SavedModel al formato .tflite.

Processo di conversione

  1. JAX in TensorFlow SavedModel utilizzando jax2tf: il primo passaggio consiste nel convertire il modello JAX nel formato TensorFlow SavedModel. Questa operazione viene eseguita utilizzando lo strumento jax2tf, che è una funzionalità sperimentale di JAX. jax2tf consente di convertire le funzioni JAX in grafici TensorFlow.

    Per istruzioni ed esempi dettagliati su come utilizzare jax2tf, consulta la documentazione ufficiale di jax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    In genere, questo processo prevede il wrapping della funzione di previsione del modello JAX con jax2tf.convert e il salvataggio utilizzando tf.saved_model.save di TensorFlow.

  2. TensorFlow SavedModel a TFLite: una volta che il modello è in formato TensorFlow SavedModel, puoi convertirlo in formato TFLite utilizzando il convertitore TensorFlow Lite standard. Questo processo ottimizza il modello per l'esecuzione sul dispositivo, riducendone le dimensioni e migliorandone le prestazioni.

    Le istruzioni dettagliate per convertire un modello SavedModel di TensorFlow in TFLite sono disponibili nella guida alla conversione dei modelli TensorFlow.

    Questa guida illustra varie opzioni e best practice per il processo di conversione, tra cui la quantizzazione e altre ottimizzazioni.

Seguendo questi due passaggi, puoi prendere i tuoi modelli sviluppati in JAX e implementarli in modo efficiente sui dispositivi edge utilizzando il runtime LiteRT.