Halaman ini menyediakan jalur bagi pengguna yang ingin melatih model di JAX dan men-deploy ke perangkat seluler untuk inferensi (contoh colab.
Metode dalam panduan ini menghasilkan tflite_model
yang dapat digunakan langsung
dengan contoh kode penafsir TFLite atau disimpan ke file TFLite FlatBuffer.
Prasyarat
Sebaiknya coba fitur ini dengan paket Python malam TensorFlow terbaru.
pip install tf-nightly --upgrade
Kita akan menggunakan library Orbax Export untuk mengekspor model JAX. Pastikan versi JAX Anda minimal 0.4.20 atau yang lebih baru.
pip install jax --upgrade
pip install orbax-export --upgrade
Mengonversi model JAX ke TensorFlow Lite
Kami menggunakan TensorFlow SavedModel sebagai format menengah antara JAX dan TensorFlow Lite. Setelah Anda memilikiSavedModel, API TensorFlow Lite yang sudah ada dapat digunakan untuk menyelesaikan proses konversi.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
Memeriksa model TFLite yang dikonversi
Setelah model dikonversi ke TFLite, Anda dapat menjalankan TFLite interpreter API untuk memeriksa output model.
# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])