نمای کلی پشتیبانی LiteRT JAX

LiteRT با بهره‌گیری از اکوسیستم TensorFlow مسیری را برای تبدیل مدل‌های JAX برای استنتاج روی دستگاه فراهم می‌کند. این فرآیند شامل یک تبدیل دو مرحله‌ای است: ابتدا از JAX به TensorFlow SavedModel و سپس از SavedModel به فرمت .tflite.

فرآیند تبدیل

  1. تبدیل JAX به TensorFlow SavedModel با استفاده از jax2tf : اولین قدم تبدیل مدل JAX شما به فرمت TensorFlow SavedModel است. این کار با استفاده از ابزار jax2tf انجام می‌شود که یک ویژگی آزمایشی JAX است. jax2tf به شما امکان می‌دهد توابع JAX را به نمودارهای TensorFlow تبدیل کنید.

    برای دستورالعمل‌های دقیق و مثال‌هایی در مورد نحوه استفاده jax2tf ، لطفاً به مستندات رسمی jax2tf مراجعه کنید: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    این فرآیند معمولاً شامل قرار دادن تابع پیش‌بینی مدل JAX شما در jax2tf.convert و سپس ذخیره آن با استفاده از tf.saved_model.save در TensorFlow است.

  2. تبدیل مدل ذخیره‌شده TensorFlow به TFLite: وقتی مدل خود را در قالب مدل ذخیره‌شده TensorFlow داشتید، می‌توانید آن را با استفاده از مبدل استاندارد TensorFlow Lite به قالب TFLite تبدیل کنید. این فرآیند، مدل را برای اجرا روی دستگاه بهینه می‌کند، اندازه آن را کاهش می‌دهد و عملکرد را بهبود می‌بخشد.

    دستورالعمل‌های دقیق برای تبدیل TensorFlow SavedModel به TFLite را می‌توانید در راهنمای تبدیل TensorFlow Model بیابید.

    این راهنما گزینه‌های مختلف و بهترین شیوه‌ها برای فرآیند تبدیل، از جمله کوانتیزاسیون و سایر بهینه‌سازی‌ها را پوشش می‌دهد.

با دنبال کردن این دو مرحله، می‌توانید مدل‌های توسعه‌یافته خود را در JAX گرفته و با استفاده از زمان اجرای LiteRT، آنها را به طور کارآمد روی دستگاه‌های لبه‌ای مستقر کنید.