LiteRT는 TensorFlow 생태계를 활용하여 기기 내 추론을 위해 JAX 모델을 변환하는 경로를 제공합니다. 이 프로세스에는 2단계 변환이 포함됩니다. 먼저 JAX에서 TensorFlow SavedModel로 변환한 다음 SavedModel에서 .tflite 형식으로 변환합니다.
변환 프로세스
jax2tf을 사용한 JAX에서 TensorFlow SavedModel로 변환: 첫 번째 단계는 JAX 모델을 TensorFlow SavedModel 형식으로 변환하는 것입니다. 이는 실험용 JAX 실험 기능인jax2tf도구를 사용하여 실행됩니다.jax2tf를 사용하면 JAX 함수를 TensorFlow 그래프로 변환할 수 있습니다.jax2tf사용 방법에 관한 자세한 안내와 예는 공식jax2tf문서를 참고하세요. https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md이 프로세스에는 일반적으로 JAX 모델의 예측 함수를
jax2tf.convert로 래핑한 다음 TensorFlow의tf.saved_model.save를 사용하여 저장하는 작업이 포함됩니다.TensorFlow SavedModel을 TFLite로: 모델이 TensorFlow SavedModel 형식으로 되어 있으면 표준 TensorFlow Lite 변환기를 사용하여 TFLite 형식으로 변환할 수 있습니다. 이 프로세스는 온디바이스 실행을 위해 모델을 최적화하여 크기를 줄이고 성능을 개선합니다.
TensorFlow SavedModel을 TFLite로 변환하는 자세한 안내는 TensorFlow 모델 변환 가이드를 참고하세요.
이 가이드에서는 양자화 및 기타 최적화를 비롯한 변환 프로세스의 다양한 옵션과 권장사항을 다룹니다.
이 두 단계를 따르면 JAX로 개발된 모델을 LiteRT 런타임을 사용하여 에지 기기에 효율적으로 배포할 수 있습니다.