LiteRT menyediakan jalur untuk mengonversi model JAX untuk inferensi di perangkat dengan memanfaatkan ekosistem TensorFlow. Proses ini melibatkan konversi dua langkah: pertama dari JAX ke TensorFlow SavedModel, lalu dari SavedModel ke format .tflite.
Proses Konversi
JAX ke TensorFlow SavedModel menggunakan
jax2tf: Langkah pertama adalah mengonversi model JAX Anda ke format TensorFlow SavedModel. Hal ini dilakukan menggunakan alatjax2tf, yang merupakan fitur eksperimental JAX.jax2tfmemungkinkan Anda mengonversi fungsi JAX menjadi grafik TensorFlow.Untuk petunjuk dan contoh mendetail tentang cara menggunakan
jax2tf, lihat dokumentasijax2tfresmi: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdProses ini biasanya melibatkan pembungkusan fungsi prediksi model JAX Anda dengan
jax2tf.convert, lalu menyimpannya menggunakantf.saved_model.saveTensorFlow.TensorFlow SavedModel ke TFLite: Setelah memiliki model dalam format TensorFlow SavedModel, Anda dapat mengonversinya ke format TFLite menggunakan konverter TensorFlow Lite standar. Proses ini mengoptimalkan model untuk eksekusi di perangkat, mengurangi ukurannya, dan meningkatkan performa.
Petunjuk mendetail untuk mengonversi SavedModel TensorFlow ke TFLite dapat ditemukan di panduan Konversi model TensorFlow.
Panduan ini mencakup berbagai opsi dan praktik terbaik untuk proses konversi, termasuk kuantisasi dan pengoptimalan lainnya.
Dengan mengikuti dua langkah ini, Anda dapat menggunakan model yang dikembangkan di JAX dan men-deploy-nya secara efisien di perangkat edge menggunakan runtime LiteRT.