LiteRT با بهرهگیری از اکوسیستم TensorFlow مسیری را برای تبدیل مدلهای JAX برای استنتاج روی دستگاه فراهم میکند. این فرآیند شامل یک تبدیل دو مرحلهای است: ابتدا از JAX به TensorFlow SavedModel و سپس از SavedModel به فرمت .tflite.
فرآیند تبدیل
تبدیل JAX به TensorFlow SavedModel با استفاده از
jax2tf: اولین قدم تبدیل مدل JAX شما به فرمت TensorFlow SavedModel است. این کار با استفاده از ابزارjax2tfانجام میشود که یک ویژگی آزمایشی JAX است.jax2tfبه شما امکان میدهد توابع JAX را به نمودارهای TensorFlow تبدیل کنید.برای دستورالعملهای دقیق و مثالهایی در مورد نحوه استفاده
jax2tf، لطفاً به مستندات رسمیjax2tfمراجعه کنید: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.mdاین فرآیند معمولاً شامل قرار دادن تابع پیشبینی مدل JAX شما در
jax2tf.convertو سپس ذخیره آن با استفاده ازtf.saved_model.saveدر TensorFlow است.تبدیل مدل ذخیرهشده TensorFlow به TFLite: وقتی مدل خود را در قالب مدل ذخیرهشده TensorFlow داشتید، میتوانید آن را با استفاده از مبدل استاندارد TensorFlow Lite به قالب TFLite تبدیل کنید. این فرآیند، مدل را برای اجرا روی دستگاه بهینه میکند، اندازه آن را کاهش میدهد و عملکرد را بهبود میبخشد.
دستورالعملهای دقیق برای تبدیل TensorFlow SavedModel به TFLite را میتوانید در راهنمای تبدیل TensorFlow Model بیابید.
این راهنما گزینههای مختلف و بهترین شیوهها برای فرآیند تبدیل، از جمله کوانتیزاسیون و سایر بهینهسازیها را پوشش میدهد.
با دنبال کردن این دو مرحله، میتوانید مدلهای توسعهیافته خود را در JAX گرفته و با استفاده از زمان اجرای LiteRT، آنها را به طور کارآمد روی دستگاههای لبهای مستقر کنید.