Обзор поддержки LiteRT JAX

LiteRT предоставляет возможность конвертировать JAX-модели для вывода на устройстве, используя экосистему TensorFlow. Процесс включает два этапа: сначала преобразование из JAX в TensorFlow SavedModel, а затем из SavedModel в формат .tflite.

Процесс преобразования

  1. JAX в TensorFlow SavedModel с помощью jax2tf : Первый шаг — конвертировать JAX-модель в формат TensorFlow SavedModel. Это делается с помощью инструмента jax2tf , экспериментальной функции JAX. jax2tf позволяет конвертировать JAX-функции в графы TensorFlow.

    Подробные инструкции и примеры использования jax2tf можно найти в официальной документации jax2tf : https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    Этот процесс обычно включает в себя обертывание функции прогнозирования модели JAX с помощью jax2tf.convert и последующее ее сохранение с помощью tf.saved_model.save TensorFlow.

  2. TensorFlow SavedModel в TFLite: После преобразования модели в формат TensorFlow SavedModel вы можете конвертировать её в формат TFLite с помощью стандартного конвертера TensorFlow Lite. Этот процесс оптимизирует модель для выполнения на устройстве, уменьшая её размер и повышая производительность.

    Подробные инструкции по преобразованию TensorFlow SavedModel в TFLite можно найти в руководстве по преобразованию моделей TensorFlow .

    В этом руководстве рассматриваются различные варианты и передовые практики процесса преобразования, включая квантизацию и другие виды оптимизации.

Выполнив эти два шага, вы сможете эффективно развертывать модели, разработанные в JAX, на периферийных устройствах с использованием среды выполнения LiteRT.