Visão geral do suporte do JAX ao LiteRT

O LiteRT oferece um caminho para converter modelos do JAX para inferência no dispositivo aproveitando o ecossistema do TensorFlow. O processo envolve uma conversão em duas etapas: primeiro de JAX para SavedModel do TensorFlow e depois de SavedModel para o formato .tflite.

Processo de conversão

  1. JAX para SavedModel do TensorFlow usando jax2tf:a primeira etapa é converter seu modelo JAX no formato SavedModel do TensorFlow. Isso é feito usando a ferramenta jax2tf, que é um recurso experimental do JAX. O jax2tf permite converter funções JAX em gráficos do TensorFlow.

    Para instruções detalhadas e exemplos de como usar jax2tf, consulte a documentação oficial do jax2tf: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md

    Esse processo geralmente envolve encapsular a função de previsão do modelo JAX com jax2tf.convert e salvá-la usando tf.saved_model.save do TensorFlow.

  2. TensorFlow SavedModel para TFLite:depois de ter o modelo no formato TensorFlow SavedModel, é possível convertê-lo para o formato TFLite usando o conversor padrão do TensorFlow Lite. Esse processo otimiza o modelo para execução no dispositivo, reduzindo o tamanho e melhorando a performance.

    As instruções detalhadas para converter um SavedModel do TensorFlow em TFLite podem ser encontradas no guia de conversão de modelos do TensorFlow.

    Este guia aborda várias opções e práticas recomendadas para o processo de conversão, incluindo quantização e outras otimizações.

Ao seguir estas duas etapas, você pode usar os modelos desenvolvidos no JAX e implantá-los de maneira eficiente em dispositivos de borda usando o ambiente de execução LiteRT.