סקירה כללית על תמיכה ב-LiteRT JAX

‫LiteRT מספקת דרך להמיר מודלים של JAX להסקת מסקנות במכשיר באמצעות מערכת TensorFlow. התהליך כולל המרה בשני שלבים: קודם מ-JAX ל-TensorFlow SavedModel, ואז מ-SavedModel לפורמט ‎ .tflite.

תהליך ההמרה

  1. JAX ל-TensorFlow SavedModel באמצעות jax2tf: השלב הראשון הוא להמיר את מודל JAX לפורמט TensorFlow SavedModel. הפעולה הזו מתבצעת באמצעות הכלי jax2tf, שהוא תכונה ניסיונית של JAX. ‫jax2tf מאפשרת להמיר פונקציות JAX לגרפים של TensorFlow.

    הוראות מפורטות ודוגמאות לשימוש ב-jax2tf זמינות במסמכי התיעוד הרשמיים של jax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    בדרך כלל התהליך הזה כולל עטיפה של פונקציית החיזוי של מודל JAX באמצעות jax2tf.convert ושמירה שלה באמצעות tf.saved_model.save של TensorFlow.

  2. TensorFlow SavedModel ל-TFLite: אחרי שהמודל שלכם בפורמט TensorFlow SavedModel, אתם יכולים להמיר אותו לפורמט TFLite באמצעות כלי ההמרה הרגיל של TensorFlow Lite. התהליך הזה מבצע אופטימיזציה של המודל לביצוע במכשיר, מקטין את הגודל שלו ומשפר את הביצועים.

    הוראות מפורטות להמרת TensorFlow SavedModel ל-TFLite זמינות במדריך להמרת מודלים של TensorFlow.

    במדריך הזה נסקור אפשרויות שונות ושיטות מומלצות לתהליך ההמרה, כולל קוונטיזציה ואופטימיזציות אחרות.

אם תפעלו לפי שני השלבים האלה, תוכלו לקחת את המודלים שפיתחתם ב-JAX ולפרוס אותם ביעילות במכשירי קצה באמצעות זמן הריצה של LiteRT.