Trang này cung cấp đường dẫn cho những người dùng muốn huấn luyện mô hình bằng JAX và triển khai
đến thiết bị di động để suy luận. Các phương thức trong hướng dẫn này tạo ra một tflite_model
có thể dùng trực tiếp với ví dụ về mã thông dịch LiteRT hoặc được lưu vào
tệp FlatBuffer tflite
.
Để xem ví dụ đầy đủ, hãy xem phần bắt đầu nhanh.
Điều kiện tiên quyết
Bạn nên dùng thử tính năng này với Python mới nhất để chạy vào ban đêm trên TensorFlow .
pip install tf-nightly --upgrade
Chúng ta sẽ sử dụng thư viện Orbax Export (Xuất Orbax) để xuất các mô hình JAX. Đảm bảo phiên bản JAX của bạn ít nhất là 0.4.20 trở lên.
pip install jax --upgrade
pip install orbax-export --upgrade
Chuyển đổi mô hình JAX sang LiteRT
Chúng tôi sử dụng SavedModel TensorFlow làm định dạng trung gian giữa JAX và LiteRT. Sau khi bạn có SavedModel thì bạn có thể sử dụng các API LiteRT hiện có để hoàn tất quá trình chuyển đổi.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
Kiểm tra mô hình TFLite đã chuyển đổi
Sau khi mô hình được chuyển đổi sang TFLite, bạn có thể chạy các API phiên dịch TFLite để kiểm tra đầu ra của mô hình.
# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])