यह पेज उन उपयोगकर्ताओं के लिए पाथ उपलब्ध कराता है जो JAX में मॉडल को ट्रेनिंग देना चाहते हैं और डिप्लॉय करना चाहते हैं मोबाइल का इस्तेमाल करें (उदाहरण के लिए colab.
इस गाइड में दिए गए तरीके से एक tflite_model
बनता है. इसका इस्तेमाल सीधे तौर पर किया जा सकता है
को TFLite अनुवादक कोड के उदाहरण के साथ या TFLite FlatBuffer फ़ाइल में सेव किया गया.
पूर्वापेक्षा
हमारा सुझाव है कि आप सबसे नए TensorFlow नाइटली Python के साथ इस सुविधा को आज़माएं पैकेज.
pip install tf-nightly --upgrade
हम Orbax लाइब्रेरी को इसमें एक्सपोर्ट करें JAX मॉडल एक्सपोर्ट करें. पक्का करें कि JAX वर्शन कम से कम 0.4.20 या इसके बाद का हो.
pip install jax --upgrade
pip install orbax-export --upgrade
JAX मॉडल को LiteRT में बदलें
हम TensorFlow का इस्तेमाल करते हैं SavedModel को इंटरमीडिएट के तौर पर इस्तेमाल करना का फ़ॉर्मैट JAX और LiteRT के बीच होना चाहिए. जब आपके पासSaved मॉडल हो, तो कन्वर्ज़न प्रोसेस को पूरा करने के लिए, मौजूदा LiteRT API का इस्तेमाल किया जा सकता है.
# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp
def model_fn(_, x):
return jnp.sin(jnp.cos(x))
jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')
# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
jax_module,
'/some/directory',
signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
),
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
'Serving_default',
# Corresponds to the input signature of `tf_preprocessor`
input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
tf_preprocessor=lambda x: x,
tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()
# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[
jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
)
]
)
tflite_model = converter.convert()
बदले गए TFLite मॉडल की जांच करें
मॉडल को TFLite में बदलने के बाद, TFLite अनुवादक एपीआई चलाया जा सकता है मॉडल आउटपुट की जाँच करें.
# Run the model with LiteRT
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])