LiteRT fornisce un percorso per convertire i modelli JAX per l'inferenza on-device sfruttando l'ecosistema TensorFlow. Il processo prevede una conversione in due passaggi: prima da JAX a TensorFlow SavedModel e poi da SavedModel al formato .tflite.
Processo di conversione
JAX in TensorFlow SavedModel utilizzando
jax2tf: il primo passaggio consiste nel convertire il modello JAX nel formato TensorFlow SavedModel. Questa operazione viene eseguita utilizzando lo strumentojax2tf, che è una funzionalità sperimentale di JAX.jax2tfconsente di convertire le funzioni JAX in grafici TensorFlow.Per istruzioni ed esempi dettagliati su come utilizzare
jax2tf, consulta la documentazione ufficiale dijax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdIn genere, questo processo prevede il wrapping della funzione di previsione del modello JAX con
jax2tf.converte il salvataggio utilizzandotf.saved_model.savedi TensorFlow.TensorFlow SavedModel a TFLite: una volta che il modello è in formato TensorFlow SavedModel, puoi convertirlo in formato TFLite utilizzando il convertitore TensorFlow Lite standard. Questo processo ottimizza il modello per l'esecuzione sul dispositivo, riducendone le dimensioni e migliorandone le prestazioni.
Le istruzioni dettagliate per convertire un modello SavedModel di TensorFlow in TFLite sono disponibili nella guida alla conversione dei modelli TensorFlow.
Questa guida illustra varie opzioni e best practice per il processo di conversione, tra cui la quantizzazione e altre ottimizzazioni.
Seguendo questi due passaggi, puoi prendere i tuoi modelli sviluppati in JAX e implementarli in modo efficiente sui dispositivi edge utilizzando il runtime LiteRT.