LiteRT ile JAX modelleri

Bu sayfa, modelleri JAX'ta eğitmek ve dağıtmak isteyen kullanıcılar için uyarlayabilirsiniz. Bu kılavuzdaki yöntemler bir tflite_model oluşturur Bu kod, doğrudan LiteRT çevirmen kodu örneğiyle kullanılabilir veya bir tflite FlatBuffer dosyası ekleyebilirsiniz.

Uçtan uca bir örnek için hızlı başlangıç sayfasına göz atın.

Ön koşul

Bu özelliği en yeni TensorFlow gecelik Python paketiyle denemeniz önerilir.

pip install tf-nightly --upgrade

Bunun için Orbax Kitaplığı şu klasöre aktarın: dışa aktarabilirsiniz. JAX sürümünüzün en az 0.4.20 veya üzeri olduğundan emin olun.

pip install jax --upgrade
pip install orbax-export --upgrade

JAX modellerini LiteRT'e dönüştürme

JAX ile LiteRT arasındaki ara biçim olarak TensorFlow SavedModel'i kullanırız. SavedModel'iniz olduğunda dönüşüm sürecini tamamlamak için mevcut LiteRT API'leri kullanılabilir.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Dönüştürülmüş TFLite modelini kontrol etme

Model TFLite'a dönüştürüldükten sonra aşağıdaki işlemler için TFLite çevirmen API'lerini çalıştırabilirsiniz: model çıkışlarını kontrol edin.

# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])