Bu sayfa, modelleri JAX'ta eğitmek ve dağıtmak isteyen kullanıcılar için
uyarlayabilirsiniz. Bu kılavuzdaki yöntemler bir tflite_model
oluşturur
Bu kod, doğrudan LiteRT çevirmen kodu örneğiyle kullanılabilir veya
bir tflite
FlatBuffer dosyası ekleyebilirsiniz.
Uçtan uca bir örnek için hızlı başlangıç sayfasına göz atın.
Ön koşul
Bu özelliği en yeni TensorFlow gecelik Python paketiyle denemeniz önerilir.
pip install tf-nightly --upgrade
Bunun için Orbax Kitaplığı şu klasöre aktarın: dışa aktarabilirsiniz. JAX sürümünüzün en az 0.4.20 veya üzeri olduğundan emin olun.
pip install jax --upgrade
pip install orbax-export --upgrade
JAX modellerini LiteRT'e dönüştürme
JAX ile LiteRT arasındaki ara biçim olarak TensorFlow SavedModel'i kullanırız. SavedModel'iniz olduğunda dönüşüm sürecini tamamlamak için mevcut LiteRT API'leri kullanılabilir.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
Dönüştürülmüş TFLite modelini kontrol etme
Model TFLite'a dönüştürüldükten sonra aşağıdaki işlemler için TFLite çevirmen API'lerini çalıştırabilirsiniz: model çıkışlarını kontrol edin.
# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])