หน้านี้มีเส้นทางสำหรับผู้ใช้ที่ต้องการฝึกโมเดลใน JAX และทำให้ใช้งานได้ บนอุปกรณ์เคลื่อนที่เพื่ออนุมาน (example colab
วิธีการในคู่มือนี้จะสร้าง tflite_model
ซึ่งใช้ได้โดยตรง
ด้วยตัวอย่างโค้ดล่าม TFLite หรือบันทึกไว้ในไฟล์ TFLite FlatBuffer
วิชาบังคับก่อน
ขอแนะนำให้ลองใช้ฟีเจอร์นี้กับ Python กลางคืน TensorFlow ใหม่ล่าสุด ใหม่
pip install tf-nightly --upgrade
เราจะใช้ฟังก์ชัน Orbax ส่งออกไลบรารีไปยัง ส่งออกโมเดล JAX ตรวจสอบว่า JAX เป็นเวอร์ชัน 0.4.20 ขึ้นไป
pip install jax --upgrade
pip install orbax-export --upgrade
แปลงโมเดล JAX เป็น LiteRT
เราใช้ TensorFlow SavedModel เป็นตัวกลาง ระหว่าง JAX และ LiteRT เมื่อคุณมี savedModel แล้ว คุณสามารถใช้ LiteRT API ที่มีอยู่เพื่อดำเนินกระบวนการแปลงให้เสร็จสมบูรณ์ได้
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
ตรวจสอบโมเดล TFLite ที่แปลงแล้ว
หลังจากแปลงโมเดลเป็น TFLite แล้ว คุณสามารถเรียกใช้ API ล่าม TFLite เพื่อ ตรวจสอบเอาต์พุตของโมเดล
# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])