Modelet JAX me LiteRT

Kjo faqe ofron një shteg për përdoruesit që duan të trajnojnë modelet në JAX dhe të vendosin në celular për konkluzione. Metodat në këtë udhëzues prodhojnë një tflite_model i cili mund të përdoret drejtpërdrejt me shembullin e kodit të interpretuesit LiteRT ose të ruhet në një skedar tflite FlatBuffer.

Për një shembull nga fundi në fund, shihni fillimin e shpejtë .

Kusht paraprak

Rekomandohet ta provoni këtë veçori me paketën më të re të Python të natës TensorFlow.

pip install tf-nightly --upgrade

Ne do të përdorim bibliotekën Orbax Export për të eksportuar modele JAX. Sigurohuni që versioni juaj JAX të jetë të paktën 0.4.20 ose më lart.

pip install jax --upgrade
pip install orbax-export --upgrade

Konvertoni modelet JAX në LiteRT

Ne përdorim TensorFlow SavedModel si format të ndërmjetëm midis JAX dhe LiteRT. Pasi të keni një SavedModel, atëherë API-të ekzistuese LiteRT mund të përdoren për të përfunduar procesin e konvertimit.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Kontrolloni modelin e konvertuar TFLite

Pasi modeli të konvertohet në TFLite, mund të ekzekutoni API-të e interpretit TFLite për të kontrolluar rezultatet e modelit.

# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])