LiteRT מספקת דרך להמיר מודלים של JAX להסקת מסקנות במכשיר באמצעות מערכת TensorFlow. התהליך כולל המרה בשני שלבים: קודם מ-JAX ל-TensorFlow SavedModel, ואז מ-SavedModel לפורמט .tflite.
תהליך ההמרה
JAX ל-TensorFlow SavedModel באמצעות
jax2tf: השלב הראשון הוא להמיר את מודל JAX לפורמט TensorFlow SavedModel. הפעולה הזו מתבצעת באמצעות הכליjax2tf, שהוא תכונה ניסיונית של JAX. jax2tfמאפשרת להמיר פונקציות JAX לגרפים של TensorFlow.הוראות מפורטות ודוגמאות לשימוש ב-
jax2tfזמינות במסמכי התיעוד הרשמיים שלjax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdבדרך כלל התהליך הזה כולל עטיפה של פונקציית החיזוי של מודל JAX באמצעות
jax2tf.convertושמירה שלה באמצעותtf.saved_model.saveשל TensorFlow.TensorFlow SavedModel ל-TFLite: אחרי שהמודל שלכם בפורמט TensorFlow SavedModel, אתם יכולים להמיר אותו לפורמט TFLite באמצעות כלי ההמרה הרגיל של TensorFlow Lite. התהליך הזה מבצע אופטימיזציה של המודל לביצוע במכשיר, מקטין את הגודל שלו ומשפר את הביצועים.
הוראות מפורטות להמרת TensorFlow SavedModel ל-TFLite זמינות במדריך להמרת מודלים של TensorFlow.
במדריך הזה נסקור אפשרויות שונות ושיטות מומלצות לתהליך ההמרה, כולל קוונטיזציה ואופטימיזציות אחרות.
אם תפעלו לפי שני השלבים האלה, תוכלו לקחת את המודלים שפיתחתם ב-JAX ולפרוס אותם ביעילות במכשירי קצה באמצעות זמן הריצה של LiteRT.