JAX-Modelle mit TensorFlow Lite

Diese Seite bietet einen Pfad für Nutzer, die Modelle in JAX trainieren und sie zur Inferenz auf Mobilgeräten bereitstellen möchten (Beispiel Colab.

Die Methoden in dieser Anleitung erzeugen einen tflite_model, der direkt mit dem Codebeispiel für den TFLite-Interpreter verwendet oder in einer TFLite FlatBuffer-Datei gespeichert werden kann.

Vorbereitung

Es wird empfohlen, dieses Feature mit dem neuesten nächtlichen Python-Paket von TensorFlow auszuprobieren.

pip install tf-nightly --upgrade

Zum Exportieren von JAX-Modellen verwenden wir die Orbax Export-Bibliothek. Stellen Sie sicher, dass Ihre JAX-Version 0.4.20 oder höher ist.

pip install jax --upgrade
pip install orbax-export --upgrade

JAX-Modelle in TensorFlow Lite konvertieren

Wir verwenden das SavedModel von TensorFlow als Zwischenformat zwischen JAX und TensorFlow Lite. Sobald Sie ein Modell haben, können Sie vorhandene TensorFlow Lite APIs verwenden, um den Konvertierungsprozess abzuschließen.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Konvertiertes TFLite-Modell prüfen

Nachdem das Modell in TFLite konvertiert wurde, können Sie die TFLite-Interpreter APIs ausführen, um die Modellausgaben zu prüfen.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])