LiteRT para Web com LiteRT.js

O LiteRT.js é o ambiente de execução WebAI de alta performance do Google, voltado para aplicativos da Web de produção. Ele é uma continuação da pilha LiteRT, garantindo suporte a vários frameworks e unificando nosso ambiente de execução principal em todas as plataformas.

O LiteRT.js oferece suporte aos seguintes recursos principais:

  1. Inferência acelerada por hardware no navegador: execute modelos com desempenho excepcional da CPU, acelerado pelo XNNPack mapeado para WebAssembly (Wasm) leve. Para escalonamento de GPU e hardware dedicado (como NPUs), o LiteRT.js mostra nativamente a WebGPU e a API WebNN emergente, permitindo uma otimização específica da plataforma.
  2. Compatibilidade com vários frameworks: simplifique a semântica de desenvolvimento compilando nativamente seu framework de ML preferido: PyTorch, JAX ou TensorFlow.
  3. Iterar em pipelines atuais: integração pronta para uso com arquiteturas do TensorFlow.js atuais, analisando tensores do TensorFlow.js com suporte nativo como entradas e saídas de limite direto.

Instalação

Instale o pacote @litertjs/core do npm:

npm install @litertjs/core

Os arquivos Wasm estão localizados em node_modules/@litertjs/core/wasm/. Para sua conveniência, copie e veicule toda a pasta wasm/. Em seguida, importe o pacote e carregue os arquivos Wasm:

import {loadLiteRt} from '@litertjs/core';

// Load the LiteRT.js Wasm files from a CDN.
await loadLiteRt('https://cdn.jsdelivr.net/npm/@litertjs/core/wasm/')
// Alternatively, host them from your server.
// They are located in node_modules/@litertjs/core/wasm/
await loadLiteRt(`your/path/to/wasm/`);

Conversão de modelos

O LiteRT.js usa o mesmo formato .tflite que o restante do ecossistema LiteRT, e oferece suporte a modelos atuais no Kaggle e Huggingface. Se você tiver um novo modelo do PyTorch, será necessário convertê-lo.

Converter um modelo do PyTorch para o LiteRT

Para converter um modelo do PyTorch para o LiteRT, use o litert-torch conversor.

import litert_torch

# Load your torch model. We're using resnet for this example.
resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

sample_inputs = (torch.randn(1, 3, 224, 224),)

# Convert the model to LiteRT.
edge_model = litert_torch.convert(resnet18.eval(), sample_inputs)

# Export the model.
edge_model.export('resnet.tflite')

Executar o modelo convertido

Depois de converter o modelo em um arquivo .tflite, você poderá executá-lo no navegador.

import {loadAndCompile} from '@litertjs/core';

// Load the model hosted from your server. This makes an http(s) request.
const model = await loadAndCompile('/path/to/model.tflite', {
    accelerator: 'webgpu',
    // Can select from 'webnn', 'webgpu', & 'wasm'.
    // Additionally, you can pass an array of accelerators e.g. ['webnn', 'wasm']
    // if you would like to fallback to CPU execution,
    // Note that ONLY cpu fallback is supported for now
    // (i.e. specifying ['webnn', 'webgpu']) will lead to compilation errors
});
// The model can also be loaded from a Uint8Array if you want to fetch it yourself.

// Create image input data
const image = new Float32Array(224 * 224 * 3).fill(0);
const inputTensor = new Tensor(image, /* shape */ [1, 3, 224, 224]);

// Run the model
const outputs = await model.run(inputTensor);
// You can also use `await model.run([inputTensor]);`
// or `await model.run({'input_tensor_name': inputTensor});`

// Clean up and get outputs
inputTensor.delete();
const output = outputs[0];
const outputData = await output.data();
output.delete();

Integrar em pipelines do TensorFlow.js

Considere integrar o LiteRT.js aos TensorFlow.js pelos seguintes motivos:

  1. Desempenho excepcional de GPU e hardware: os modelos do LiteRT.js aproveitam a aceleração do WebGPU para otimizar o desempenho em arquiteturas de navegador. Com suporte ao WebGPU e ao WebNN, o LiteRT.js oferece aceleração de hardware flexível em vários dispositivos de borda.
  2. Caminho de conversão de modelo mais fácil: o caminho de conversão do LiteRT.js vai diretamente do PyTorch para o LiteRT. O caminho de conversão do PyTorch para o TensorFlow.js é muito mais complicado, exigindo que você vá do PyTorch -> ONNX -> TensorFlow -> TensorFlow.js.
  3. Ferramentas de depuração: o caminho de conversão do LiteRT.js vem com ferramentas de depuração.

O LiteRT.js foi projetado para funcionar em pipelines do TensorFlow.js e é compatível com o pré-processamento e pós-processamento do TensorFlow.js. Portanto, a única coisa que você precisa migrar é o modelo.

Integre o LiteRT.js aos pipelines do TensorFlow.js seguindo estas etapas:

  1. Converta seu modelo original do TensorFlow, JAX ou PyTorch para .tflite. Para mais detalhes, consulte a seção de conversão de modelos.
  2. Instale os pacotes @litertjs/core e @litertjs/tfjs-interop do NPM.
  3. Importe e use o back-end do TensorFlow.js WebGPU. Isso é necessário para que o LiteRT.js funcione com o TensorFlow.js.
  4. Substitua o carregamento do modelo do TensorFlow.js por o carregamento do modelo do LiteRT.js.
  5. Substitua model.predict(inputs) ou model.execute(inputs) do TensorFlow.js por runWithTfjsTensors(liteRtModel, inputs). runWithTfjsTensors usa os mesmos tensores de entrada que os modelos do TensorFlow.js e gera tensores do TensorFlow.js.
  6. Teste se o pipeline do modelo gera os resultados esperados.

O uso do LiteRT.js com runWithTfjsTensors também pode exigir as seguintes mudanças nas entradas do modelo:

  1. Reordenar entradas: dependendo de como o conversor ordenou as entradas e saídas do modelo, talvez seja necessário mudar a ordem delas ao transmiti-las.
  2. Transpor entradas: também é possível que o conversor tenha mudado o layout das entradas e saídas do modelo em comparação com o que o TensorFlow.js usa. Talvez seja necessário transpor as entradas para corresponder ao modelo e as saídas para corresponder ao restante do pipeline.
  3. Renomear entradas: se você estiver usando entradas nomeadas, os nomes também podem ter mudado.

Você pode acessar mais informações sobre as entradas e saídas do modelo com model.getInputDetails() e model.getOutputDetails().