Bu sayfa, JAX'ta modelleri eğitmek ve çıkarım için mobil cihaza dağıtmak isteyen kullanıcılar için bir yol sunar (örnek colab.
Bu kılavuzdaki yöntemler, doğrudan TFLite çevirmen kodu örneğiyle kullanılabilecek veya bir TFLite FlatBuffer dosyasına kaydedilebilen bir tflite_model
oluşturur.
Ön koşul
Bu özelliği, en yeni TensorFlow gecelik Python paketiyle denemeniz önerilir.
pip install tf-nightly --upgrade
JAX modellerini dışa aktarmak için Orbax Export kitaplığını kullanacağız. JAX sürümünüzün en az 0.4.20 veya üzeri olduğundan emin olun.
pip install jax --upgrade
pip install orbax-export --upgrade
JAX modellerini TensorFlow Lite'a dönüştürme
JAX ile TensorFlow Lite arasında ara biçim olarak TensorFlow SavedModel'i kullanıyoruz. Bir SavedModel'iniz varsa dönüştürme işlemini tamamlamak için mevcut TensorFlow Lite API'leri kullanılabilir.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
Dönüştürülen TFLite modelini kontrol etme
Model TFLite'a dönüştürüldükten sonra, model çıkışlarını kontrol etmek için TFLite çevirmen API'lerini çalıştırabilirsiniz.
# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])