LiteRT für Web mit LiteRT.js

LiteRT.js ist die leistungsstarke WebAI-Laufzeit von Google für Webanwendungen in der Produktion. Es ist eine Fortsetzung des LiteRT-Stacks, die Multi-Framework-Unterstützung bietet und unsere Core-Laufzeitumgebung auf allen Plattformen vereinheitlicht.

LiteRT.js unterstützt die folgenden Kernfunktionen:

  1. Hardwarebeschleunigte Inferenz im Browser: Führen Sie Modelle mit außergewöhnlicher CPU-Leistung aus, die durch XNNPack beschleunigt und auf leichtgewichtige WebAssembly-Module (Wasm) abgebildet werden. Für die Skalierung von GPUs und dedizierter Hardware (z. B. NPUs) werden in LiteRT.js sowohl die WebGPU API als auch die neue WebNN API nativ bereitgestellt, was eine detaillierte plattformspezifische Optimierung ermöglicht.
  2. Kompatibilität mit mehreren Frameworks: Optimieren Sie die Entwicklung durch die native Kompilierung aus Ihrem bevorzugten ML-Framework: PyTorch, JAX oder TensorFlow.
  3. Vorhandene Pipelines iterieren: Die sofort einsatzbereite Integration mit vorhandenen TensorFlow.js-Architekturen erfolgt durch das Parsen von nativ unterstützten TensorFlow.js-Tensoren als direkte Grenzwerteingaben und ‑ausgaben.

Installation

Installieren Sie das Paket @litertjs/core von npm:

npm install @litertjs/core

Die Wasm-Dateien befinden sich in node_modules/@litertjs/core/wasm/. Kopieren und stellen Sie den gesamten wasm/-Ordner bereit. Importieren Sie dann das Paket und laden Sie die Wasm-Dateien:

import {loadLiteRt} from '@litertjs/core';

// Load the LiteRT.js Wasm files from a CDN.
await loadLiteRt('https://cdn.jsdelivr.net/npm/@litertjs/core/wasm/')
// Alternatively, host them from your server.
// They are located in node_modules/@litertjs/core/wasm/
await loadLiteRt(`your/path/to/wasm/`);

Modellkonvertierung

LiteRT.js verwendet dasselbe .tflite-Format wie der Rest des LiteRT-Ökosystems und unterstützt vorhandene Modelle auf Kaggle und Huggingface. Wenn Sie ein neues PyTorch-Modell haben, müssen Sie es konvertieren.

PyTorch-Modell in LiteRT konvertieren

Verwenden Sie den litert-torch-Konverter, um ein PyTorch-Modell in LiteRT zu konvertieren.

import litert_torch

# Load your torch model. We're using resnet for this example.
resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

sample_inputs = (torch.randn(1, 3, 224, 224),)

# Convert the model to LiteRT.
edge_model = litert_torch.convert(resnet18.eval(), sample_inputs)

# Export the model.
edge_model.export('resnet.tflite')

Konvertiertes Modell ausführen

Nachdem Sie das Modell in eine .tflite-Datei konvertiert haben, können Sie es im Browser ausführen.

import {loadAndCompile} from '@litertjs/core';

// Load the model hosted from your server. This makes an http(s) request.
const model = await loadAndCompile('/path/to/model.tflite', {
    accelerator: 'webgpu',
    // Can select from 'webnn', 'webgpu', & 'wasm'.
    // Additionally, you can pass an array of accelerators e.g. ['webnn', 'wasm']
    // if you would like to fallback to CPU execution,
    // Note that ONLY cpu fallback is supported for now
    // (i.e. specifying ['webnn', 'webgpu']) will lead to compilation errors
});
// The model can also be loaded from a Uint8Array if you want to fetch it yourself.

// Create image input data
const image = new Float32Array(224 * 224 * 3).fill(0);
const inputTensor = new Tensor(image, /* shape */ [1, 3, 224, 224]);

// Run the model
const outputs = await model.run(inputTensor);
// You can also use `await model.run([inputTensor]);`
// or `await model.run({'input_tensor_name': inputTensor});`

// Clean up and get outputs
inputTensor.delete();
const output = outputs[0];
const outputData = await output.data();
output.delete();

In bestehende TensorFlow.js-Pipelines einbinden

Sie sollten LiteRT.js aus den folgenden Gründen in Ihre TensorFlow.js-Pipelines einbinden:

  1. Hervorragende GPU- und Hardwareleistung: LiteRT.js-Modelle nutzen die WebGPU-Beschleunigung für eine optimierte Leistung in verschiedenen Browserarchitekturen. LiteRT.js unterstützt WebGPU und das kommende WebNN und bietet so eine flexible Hardwarebeschleunigung auf einer Vielzahl von Edge-Geräten.
  2. Einfacherer Modellkonvertierungspfad: Der Konvertierungspfad von LiteRT.js führt direkt von PyTorch zu LiteRT. Die Konvertierung von PyTorch zu TensorFlow.js ist wesentlich komplizierter, da Sie von PyTorch -> ONNX -> TensorFlow -> TensorFlow.js wechseln müssen.
  3. Debugging-Tools: Der LiteRT.js-Conversion-Pfad enthält Debugging-Tools.

LiteRT.js ist für die Verwendung in TensorFlow.js-Pipelines konzipiert und mit der Vor- und Nachbearbeitung von TensorFlow.js kompatibel. Sie müssen also nur das Modell selbst migrieren.

So binden Sie LiteRT.js in TensorFlow.js-Pipelines ein:

  1. Konvertieren Sie Ihr ursprüngliches TensorFlow-, JAX- oder PyTorch-Modell in .tflite. Weitere Informationen finden Sie im Abschnitt Modellkonvertierung.
  2. Installieren Sie die NPM-Pakete @litertjs/core und @litertjs/tfjs-interop.
  3. Importieren und verwenden Sie das TensorFlow.js WebGPU-Backend. Dies ist erforderlich, damit LiteRT.js mit TensorFlow.js zusammenarbeiten kann.
  4. Ersetzen Sie TensorFlow.js-Modell laden durch LiteRT.js-Modell laden.
  5. Ersetzen Sie model.predict(inputs) oder model.execute(inputs) von TensorFlow.js durch runWithTfjsTensors(liteRtModel, inputs). runWithTfjsTensors verwendet dieselben Eingabetensoren wie TensorFlow.js-Modelle und gibt TensorFlow.js-Tensoren aus.
  6. Testen Sie, ob die Modellpipeline die erwarteten Ergebnisse ausgibt.

Wenn Sie LiteRT.js mit runWithTfjsTensors verwenden, sind möglicherweise auch die folgenden Änderungen an den Modelleingaben erforderlich:

  1. Eingaben neu anordnen: Je nachdem, wie der Converter die Ein- und Ausgaben des Modells angeordnet hat, müssen Sie möglicherweise ihre Reihenfolge ändern, wenn Sie sie übergeben.
  2. Eingaben transponieren: Möglicherweise hat der Konverter das Layout der Ein- und Ausgaben des Modells im Vergleich zu TensorFlow.js geändert. Möglicherweise müssen Sie Ihre Eingaben transponieren, damit sie dem Modell entsprechen, und Ihre Ausgaben, damit sie dem Rest der Pipeline entsprechen.
  3. Eingaben umbenennen: Wenn Sie benannte Eingaben verwenden, haben sich möglicherweise auch die Namen geändert.

Mit model.getInputDetails() und model.getOutputDetails() können Sie weitere Informationen zu den Ein- und Ausgaben des Modells abrufen.