LiteRT umożliwia konwertowanie modeli JAX na potrzeby wnioskowania na urządzeniu przez wykorzystanie ekosystemu TensorFlow. Proces ten obejmuje 2-etapową konwersję: najpierw z JAX na TensorFlow SavedModel, a potem z SavedModel na format .tflite.
Proces konwersji
JAX do TensorFlow SavedModel za pomocą
jax2tf: pierwszym krokiem jest przekonwertowanie modelu JAX na format TensorFlow SavedModel. Odbywa się to za pomocą narzędziajax2tf, które jest eksperymentalną funkcją JAX.jax2tfumożliwia przekształcanie funkcji JAX w wykresy TensorFlow.Szczegółowe instrukcje i przykłady użycia
jax2tfznajdziesz w oficjalnej dokumentacjijax2tf:https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdZwykle polega to na opakowaniu funkcji prognozowania modelu JAX za pomocą funkcji
jax2tf.convert, a następnie zapisaniu jej za pomocą funkcjitf.saved_model.saveTensorFlow.TensorFlow SavedModel do TFLite: gdy masz już model w formacie TensorFlow SavedModel, możesz przekonwertować go na format TFLite za pomocą standardowego konwertera TensorFlow Lite. Ten proces optymalizuje model pod kątem wykonywania na urządzeniu, zmniejszając jego rozmiar i zwiększając wydajność.
Szczegółowe instrukcje konwertowania modelu TensorFlow SavedModel na TFLite znajdziesz w przewodniku po konwersji modeli TensorFlow.
Ten przewodnik zawiera różne opcje i sprawdzone metody konwersji, w tym kwantyzację i inne optymalizacje.
Wykonując te 2 kroki, możesz efektywnie wdrażać na urządzeniach brzegowych modele opracowane w JAX przy użyciu środowiska wykonawczego LiteRT.