TensorFlow Lite के साथ JAX मॉडल

इस पेज पर उन लोगों के लिए पाथ दिया गया है जो JAX में मॉडल को ट्रेनिंग देना चाहते हैं और अनुमान लगाने के लिए मोबाइल पर डिप्लॉय करना चाहते हैं (उदाहरण के लिए, कोलैब.

इस गाइड में दिए गए तरीके से एक tflite_model मिलता है. इसका इस्तेमाल सीधे TFLite इंटरप्रेटर कोड के उदाहरण के साथ किया जा सकता है या TFLite फ़्लैटबफ़र फ़ाइल में सेव किया जा सकता है.

पहले से आवश्यक

हमारा सुझाव है कि आप इस सुविधा को, हर रात के लिए TensorFlow के नए Python पैकेज के साथ आज़माएं.

pip install tf-nightly --upgrade

हम JAX मॉडल एक्सपोर्ट करने के लिए, Orbax Export लाइब्रेरी का इस्तेमाल करेंगे. पक्का करें कि आपका JAX वर्शन कम से कम 0.4.20 या उसके बाद का हो.

pip install jax --upgrade
pip install orbax-export --upgrade

JAX मॉडल को TensorFlow Lite में बदलें

हम JAX और TensorFlow Lite के बीच, इंटरमीडिएट फ़ॉर्मैट के तौर पर TensorFlow SavedModel का इस्तेमाल करते हैं. सेव किए गए मॉडल का इस्तेमाल करने के बाद, कन्वर्ज़न की प्रोसेस को पूरा करने के लिए मौजूदा TensorFlow Lite एपीआई का इस्तेमाल किया जा सकता है.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

बदला गया TFLite मॉडल देखें

मॉडल के TFLite में बदलने के बाद, आपके पास TFLite इंटरप्रेटर एपीआई चलाने का विकल्प होता है. इससे, मॉडल के आउटपुट की जांच की जा सकती है.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])