TensorFlow Lite를 사용한 JAX 모델

이 페이지에서는 JAX에서 모델을 학습시키고 추론을 위해 모바일에 배포하려는 사용자를 위한 경로를 제공합니다 (Colab 예시).

이 가이드의 메서드는 TFLite 인터프리터 코드 예와 함께 직접 사용하거나 TFLite FlatBuffer 파일에 저장할 수 있는 tflite_model를 생성합니다.

기본 요건

최신 TensorFlow 나이틀리 Python 패키지와 함께 이 기능을 사용해 보는 것이 좋습니다.

pip install tf-nightly --upgrade

Orbax Export 라이브러리를 사용하여 JAX 모델을 내보냅니다. JAX 버전이 0.4.20 이상인지 확인합니다.

pip install jax --upgrade
pip install orbax-export --upgrade

JAX 모델을 TensorFlow Lite로 변환

Google에서는 JAX와 TensorFlow Lite 간의 중간 형식으로 TensorFlow SavedModel을 사용합니다. 저장된 모델이 있으면 기존 TensorFlow Lite API를 사용하여 변환 프로세스를 완료할 수 있습니다.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

변환된 TFLite 모델 확인

모델이 TFLite로 변환된 후에는 TFLite 인터프리터 API를 실행하여 모델 출력을 확인할 수 있습니다.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])