ภาพรวมการรองรับ JAX ของ LiteRT

LiteRT ช่วยให้คุณแปลงโมเดล JAX สำหรับการอนุมานในอุปกรณ์ได้โดย ใช้ประโยชน์จากระบบนิเวศของ TensorFlow กระบวนการนี้เกี่ยวข้องกับการแปลง 2 ขั้นตอน ได้แก่ ขั้นตอนแรกคือการแปลงจาก JAX เป็น TensorFlow SavedModel และขั้นตอนที่ 2 คือการแปลงจาก SavedModel เป็นรูปแบบ .tflite

กระบวนการ Conversion

  1. JAX เป็น TensorFlow SavedModel โดยใช้ jax2tf: ขั้นตอนแรกคือการ แปลงโมเดล JAX เป็นรูปแบบ TensorFlow SavedModel โดยทำได้โดยใช้เครื่องมือ jax2tf ซึ่งเป็นฟีเจอร์ทดลองของ JAX jax2tf ช่วยให้คุณแปลงฟังก์ชัน JAX เป็นกราฟ TensorFlow ได้

    ดูวิธีการและตัวอย่างโดยละเอียดเกี่ยวกับวิธีใช้ jax2tf ได้ที่ เอกสารประกอบอย่างเป็นทางการของ jax2tf https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    โดยปกติแล้ว กระบวนการนี้จะเกี่ยวข้องกับการห่อหุ้มฟังก์ชันการคาดการณ์ของโมเดล JAX ด้วย jax2tf.convert จากนั้นบันทึกโดยใช้ tf.saved_model.save ของ TensorFlow

  2. TensorFlow SavedModel เป็น TFLite: เมื่อมีโมเดลในรูปแบบ TensorFlow SavedModel แล้ว คุณจะแปลงเป็นรูปแบบ TFLite ได้โดยใช้ตัวแปลง TensorFlow Lite มาตรฐาน กระบวนการนี้จะเพิ่มประสิทธิภาพโมเดลสำหรับการ ดำเนินการในอุปกรณ์ ลดขนาด และปรับปรุงประสิทธิภาพ

    ดูวิธีการโดยละเอียดในการแปลง TensorFlow SavedModel เป็น TFLite ได้ในคู่มือการแปลงโมเดล TensorFlow

    คู่มือนี้ครอบคลุมตัวเลือกและแนวทางปฏิบัติแนะนำต่างๆ สำหรับกระบวนการ Conversion รวมถึงการหาปริมาณและการเพิ่มประสิทธิภาพอื่นๆ

เมื่อทำตาม 2 ขั้นตอนนี้ คุณจะนำโมเดลที่พัฒนาใน JAX ไปใช้ได้อย่างมีประสิทธิภาพในอุปกรณ์ Edge โดยใช้รันไทม์ LiteRT