LiteRT cung cấp một đường dẫn để chuyển đổi các mô hình JAX cho suy luận trên thiết bị bằng cách tận dụng hệ sinh thái TensorFlow. Quá trình này bao gồm hai bước chuyển đổi: đầu tiên là từ JAX sang TensorFlow SavedModel, sau đó là từ SavedModel sang định dạng .tflite.
Quy trình chuyển đổi
JAX sang TensorFlow SavedModel bằng
jax2tf: Bước đầu tiên là chuyển đổi mô hình JAX của bạn sang định dạng TensorFlow SavedModel. Việc này được thực hiện bằng công cụjax2tf, đây là một tính năng thử nghiệm của JAX.jax2tfcho phép bạn chuyển đổi các hàm JAX thành biểu đồ TensorFlow.Để biết hướng dẫn và ví dụ chi tiết về cách sử dụng
jax2tf, vui lòng tham khảo tài liệu chính thức vềjax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdQuá trình này thường liên quan đến việc bao bọc hàm dự đoán của mô hình JAX bằng
jax2tf.convertrồi lưu hàm đó bằngtf.saved_model.savecủa TensorFlow.TensorFlow SavedModel sang TFLite: Sau khi có mô hình ở định dạng TensorFlow SavedModel, bạn có thể chuyển đổi mô hình đó sang định dạng TFLite bằng trình chuyển đổi TensorFlow Lite tiêu chuẩn. Quá trình này tối ưu hoá mô hình để thực thi trên thiết bị, giảm kích thước và cải thiện hiệu suất.
Bạn có thể xem hướng dẫn chi tiết về cách chuyển đổi TensorFlow SavedModel sang TFLite trong hướng dẫn Chuyển đổi mô hình TensorFlow.
Hướng dẫn này trình bày nhiều lựa chọn và phương pháp hay nhất cho quy trình chuyển đổi, bao gồm cả việc lượng tử hoá và các phương pháp tối ưu hoá khác.
Bằng cách làm theo hai bước này, bạn có thể lấy các mô hình đã phát triển trong JAX và triển khai chúng một cách hiệu quả trên các thiết bị biên bằng cách sử dụng thời gian chạy LiteRT.