LiteRT bietet einen Pfad zum Konvertieren von JAX-Modellen für die Inferenz auf dem Gerät, indem das TensorFlow-Ökosystem genutzt wird. Der Prozess umfasst eine zweistufige Konvertierung: zuerst von JAX zu TensorFlow SavedModel und dann von SavedModel zum .tflite-Format.
Konvertierungsprozess
JAX in TensorFlow SavedModel mit
jax2tf:Im ersten Schritt müssen Sie Ihr JAX-Modell in das TensorFlow SavedModel-Format konvertieren. Dazu wird dasjax2tf-Tool verwendet, eine experimentelle JAX-Funktion. Mitjax2tfkönnen Sie JAX-Funktionen in TensorFlow-Diagramme konvertieren.Eine ausführliche Anleitung und Beispiele zur Verwendung von
jax2tffinden Sie in der offiziellenjax2tf-Dokumentation: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md.Dazu müssen Sie in der Regel die Vorhersagefunktion Ihres JAX-Modells mit
jax2tf.convertumschließen und dann mittf.saved_model.savevon TensorFlow speichern.TensorFlow SavedModel zu TFLite:Wenn Sie Ihr Modell im TensorFlow SavedModel-Format haben, können Sie es mit dem Standard-TensorFlow Lite-Converter in das TFLite-Format konvertieren. Bei diesem Prozess wird das Modell für die Ausführung auf dem Gerät optimiert, wodurch seine Größe reduziert und die Leistung verbessert wird.
Eine detaillierte Anleitung zum Konvertieren eines TensorFlow-SavedModel in TFLite finden Sie in der TensorFlow-Anleitung zur Modellkonvertierung.
In diesem Leitfaden werden verschiedene Optionen und Best Practices für den Konvertierungsprozess beschrieben, einschließlich Quantisierung und anderer Optimierungen.
Wenn Sie diese beiden Schritte ausführen, können Sie Ihre in JAX entwickelten Modelle effizient auf Edge-Geräten mit der LiteRT-Laufzeit bereitstellen.