Omówienie obsługi LiteRT JAX

LiteRT umożliwia konwertowanie modeli JAX na potrzeby wnioskowania na urządzeniu przez wykorzystanie ekosystemu TensorFlow. Proces ten obejmuje 2-etapową konwersję: najpierw z JAX na TensorFlow SavedModel, a potem z SavedModel na format .tflite.

Proces konwersji

  1. JAX do TensorFlow SavedModel za pomocą jax2tf: pierwszym krokiem jest przekonwertowanie modelu JAX na format TensorFlow SavedModel. Odbywa się to za pomocą narzędzia jax2tf, które jest eksperymentalną funkcją JAX. jax2tf umożliwia przekształcanie funkcji JAX w wykresy TensorFlow.

    Szczegółowe instrukcje i przykłady użycia jax2tf znajdziesz w oficjalnej dokumentacji jax2tf:https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    Zwykle polega to na opakowaniu funkcji prognozowania modelu JAX za pomocą funkcji jax2tf.convert, a następnie zapisaniu jej za pomocą funkcji tf.saved_model.save TensorFlow.

  2. TensorFlow SavedModel do TFLite: gdy masz już model w formacie TensorFlow SavedModel, możesz przekonwertować go na format TFLite za pomocą standardowego konwertera TensorFlow Lite. Ten proces optymalizuje model pod kątem wykonywania na urządzeniu, zmniejszając jego rozmiar i zwiększając wydajność.

    Szczegółowe instrukcje konwertowania modelu TensorFlow SavedModel na TFLite znajdziesz w przewodniku po konwersji modeli TensorFlow.

    Ten przewodnik zawiera różne opcje i sprawdzone metody konwersji, w tym kwantyzację i inne optymalizacje.

Wykonując te 2 kroki, możesz efektywnie wdrażać na urządzeniach brzegowych modele opracowane w JAX przy użyciu środowiska wykonawczego LiteRT.