JAX-Modelle mit LiteRT

Diese Seite bietet einen Pfad für Nutzer, die Modelle in JAX trainieren und für die Inferenz an Mobilgeräte (Beispiel Colab.

Mit den Methoden in diesem Leitfaden wird ein tflite_model erstellt, das direkt verwendet werden kann. mit dem TFLite-Interpreter-Codebeispiel enthalten oder in einer TFLite FlatBuffer-Datei gespeichert.

Vorbereitung

Es wird empfohlen, dieses Feature mit der neuesten nächtlichen TensorFlow-Version von Python zu testen Paket.

pip install tf-nightly --upgrade

Wir verwenden Orbax Bibliothek nach exportieren JAX-Modelle exportieren können. Stellen Sie sicher, dass Ihre JAX-Version mindestens 0.4.20 ist.

pip install jax --upgrade
pip install orbax-export --upgrade

JAX-Modelle in LiteRT umwandeln

Wir verwenden die TensorFlow- SavedModel als Zwischenziel zwischen JAX und LiteRT. Wenn Sie ein SavedModel haben, Vorhandene LiteRT APIs können verwendet werden, um den Konvertierungsprozess abzuschließen.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Konvertiertes TFLite-Modell prüfen

Nachdem das Modell in TFLite konvertiert wurde, können Sie TFLite Interpreter APIs ausführen, um Modellausgaben zu prüfen.

# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])