LiteRT JAX サポートの概要

LiteRT は、TensorFlow エコシステムを活用して、オンデバイス推論用に JAX モデルを変換するパスを提供します。このプロセスでは、まず JAX から TensorFlow SavedModel に変換し、次に SavedModel から .tflite 形式に変換するという 2 段階の変換が行われます。

コンバージョン プロセス

  1. jax2tf を使用した JAX から TensorFlow SavedModel への変換: 最初のステップは、JAX モデルを TensorFlow SavedModel 形式に変換することです。これは、試験運用版の JAX 試験運用機能である jax2tf ツールを使用して行われます。jax2tf を使用すると、JAX 関数を TensorFlow グラフに変換できます。

    jax2tf の使用方法に関する詳細な手順と例については、jax2tf の公式ドキュメント(https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)をご覧ください。

    通常、このプロセスでは、JAX モデルの予測関数を jax2tf.convert でラップし、TensorFlow の tf.saved_model.save を使用して保存します。

  2. TensorFlow SavedModel から TFLite: モデルが TensorFlow SavedModel 形式になったら、標準の TensorFlow Lite コンバータを使用して TFLite 形式に変換できます。このプロセスでは、モデルをデバイス上での実行用に最適化し、サイズを縮小してパフォーマンスを向上させます。

    TensorFlow SavedModel を TFLite に変換する詳細な手順については、TensorFlow モデル変換ガイドをご覧ください。

    このガイドでは、量子化などの最適化を含む、変換プロセスのさまざまなオプションとベスト プラクティスについて説明します。

この 2 つの手順を行うことで、JAX で開発したモデルを LiteRT ランタイムを使用してエッジデバイスに効率的にデプロイできます。