Tổng quan về tính năng hỗ trợ JAX của LiteRT

LiteRT cung cấp một đường dẫn để chuyển đổi các mô hình JAX cho suy luận trên thiết bị bằng cách tận dụng hệ sinh thái TensorFlow. Quá trình này bao gồm hai bước chuyển đổi: đầu tiên là từ JAX sang TensorFlow SavedModel, sau đó là từ SavedModel sang định dạng .tflite.

Quy trình chuyển đổi

  1. JAX sang TensorFlow SavedModel bằng jax2tf: Bước đầu tiên là chuyển đổi mô hình JAX của bạn sang định dạng TensorFlow SavedModel. Việc này được thực hiện bằng công cụ jax2tf, đây là một tính năng thử nghiệm của JAX. jax2tf cho phép bạn chuyển đổi các hàm JAX thành biểu đồ TensorFlow.

    Để biết hướng dẫn và ví dụ chi tiết về cách sử dụng jax2tf, vui lòng tham khảo tài liệu chính thức về jax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    Quá trình này thường liên quan đến việc bao bọc hàm dự đoán của mô hình JAX bằng jax2tf.convert rồi lưu hàm đó bằng tf.saved_model.save của TensorFlow.

  2. TensorFlow SavedModel sang TFLite: Sau khi có mô hình ở định dạng TensorFlow SavedModel, bạn có thể chuyển đổi mô hình đó sang định dạng TFLite bằng trình chuyển đổi TensorFlow Lite tiêu chuẩn. Quá trình này tối ưu hoá mô hình để thực thi trên thiết bị, giảm kích thước và cải thiện hiệu suất.

    Bạn có thể xem hướng dẫn chi tiết về cách chuyển đổi TensorFlow SavedModel sang TFLite trong hướng dẫn Chuyển đổi mô hình TensorFlow.

    Hướng dẫn này trình bày nhiều lựa chọn và phương pháp hay nhất cho quy trình chuyển đổi, bao gồm cả việc lượng tử hoá và các phương pháp tối ưu hoá khác.

Bằng cách làm theo hai bước này, bạn có thể lấy các mô hình đã phát triển trong JAX và triển khai chúng một cách hiệu quả trên các thiết bị biên bằng cách sử dụng thời gian chạy LiteRT.