LiteRT – JAX-Unterstützung – Übersicht

LiteRT bietet einen Pfad zum Konvertieren von JAX-Modellen für die Inferenz auf dem Gerät, indem das TensorFlow-Ökosystem genutzt wird. Der Prozess umfasst eine zweistufige Konvertierung: zuerst von JAX zu TensorFlow SavedModel und dann von SavedModel zum .tflite-Format.

Konvertierungsprozess

  1. JAX in TensorFlow SavedModel mit jax2tf:Im ersten Schritt müssen Sie Ihr JAX-Modell in das TensorFlow SavedModel-Format konvertieren. Dazu wird das jax2tf-Tool verwendet, eine experimentelle JAX-Funktion. Mit jax2tf können Sie JAX-Funktionen in TensorFlow-Diagramme konvertieren.

    Eine ausführliche Anleitung und Beispiele zur Verwendung von jax2tf finden Sie in der offiziellen jax2tf-Dokumentation: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md.

    Dazu müssen Sie in der Regel die Vorhersagefunktion Ihres JAX-Modells mit jax2tf.convert umschließen und dann mit tf.saved_model.save von TensorFlow speichern.

  2. TensorFlow SavedModel zu TFLite:Wenn Sie Ihr Modell im TensorFlow SavedModel-Format haben, können Sie es mit dem Standard-TensorFlow Lite-Converter in das TFLite-Format konvertieren. Bei diesem Prozess wird das Modell für die Ausführung auf dem Gerät optimiert, wodurch seine Größe reduziert und die Leistung verbessert wird.

    Eine detaillierte Anleitung zum Konvertieren eines TensorFlow-SavedModel in TFLite finden Sie in der TensorFlow-Anleitung zur Modellkonvertierung.

    In diesem Leitfaden werden verschiedene Optionen und Best Practices für den Konvertierungsprozess beschrieben, einschließlich Quantisierung und anderer Optimierungen.

Wenn Sie diese beiden Schritte ausführen, können Sie Ihre in JAX entwickelten Modelle effizient auf Edge-Geräten mit der LiteRT-Laufzeit bereitstellen.