LiteRT предоставляет возможность конвертировать JAX-модели для вывода на устройстве, используя экосистему TensorFlow. Процесс включает два этапа: сначала преобразование из JAX в TensorFlow SavedModel, а затем из SavedModel в формат .tflite.
Процесс преобразования
JAX в TensorFlow SavedModel с помощью
jax2tf: Первый шаг — конвертировать JAX-модель в формат TensorFlow SavedModel. Это делается с помощью инструментаjax2tf, экспериментальной функции JAX.jax2tfпозволяет конвертировать JAX-функции в графы TensorFlow.Подробные инструкции и примеры использования
jax2tfможно найти в официальной документацииjax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdЭтот процесс обычно включает в себя обертывание функции прогнозирования модели JAX с помощью
jax2tf.convertи последующее ее сохранение с помощьюtf.saved_model.saveTensorFlow.TensorFlow SavedModel в TFLite: После преобразования модели в формат TensorFlow SavedModel вы можете конвертировать её в формат TFLite с помощью стандартного конвертера TensorFlow Lite. Этот процесс оптимизирует модель для выполнения на устройстве, уменьшая её размер и повышая производительность.
Подробные инструкции по преобразованию TensorFlow SavedModel в TFLite можно найти в руководстве по преобразованию моделей TensorFlow .
В этом руководстве рассматриваются различные варианты и передовые практики процесса преобразования, включая квантизацию и другие виды оптимизации.
Выполнив эти два шага, вы сможете эффективно развертывать модели, разработанные в JAX, на периферийных устройствах с использованием среды выполнения LiteRT.