TensorFlow Lite ile JAX modelleri

Bu sayfa, JAX'ta modelleri eğitmek ve çıkarım için mobil cihaza dağıtmak isteyen kullanıcılar için bir yol sunar (örnek colab.

Bu kılavuzdaki yöntemler, doğrudan TFLite çevirmen kodu örneğiyle kullanılabilecek veya bir TFLite FlatBuffer dosyasına kaydedilebilen bir tflite_model oluşturur.

Ön koşul

Bu özelliği, en yeni TensorFlow gecelik Python paketiyle denemeniz önerilir.

pip install tf-nightly --upgrade

JAX modellerini dışa aktarmak için Orbax Export kitaplığını kullanacağız. JAX sürümünüzün en az 0.4.20 veya üzeri olduğundan emin olun.

pip install jax --upgrade
pip install orbax-export --upgrade

JAX modellerini TensorFlow Lite'a dönüştürme

JAX ile TensorFlow Lite arasında ara biçim olarak TensorFlow SavedModel'i kullanıyoruz. Bir SavedModel'iniz varsa dönüştürme işlemini tamamlamak için mevcut TensorFlow Lite API'leri kullanılabilir.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Dönüştürülen TFLite modelini kontrol etme

Model TFLite'a dönüştürüldükten sonra, model çıkışlarını kontrol etmek için TFLite çevirmen API'lerini çalıştırabilirsiniz.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])