LiteRT.js を使ってみる

これは、PyTorch モデルを変換して WebGPU アクセラレーションでブラウザで実行するプロセスを網羅した LiteRT.js のエンドツーエンド ガイドです。この例では、ビジョン モデルに ResNet18 を使用し、前処理と後処理に TensorFlow.js を使用します。

このガイドでは、次の手順について説明します。

  1. AI Edge Torch を使用して、PyTorch モデルを LiteRT に変換します。
    1. LiteRT パッケージをウェブアプリに追加します。
  2. モデルを読み込みます。
  3. 前処理と後処理のロジックを記述します。

LiteRT に変換する

PyTorch Converter ノートブックを使用して、PyTorch モデルを適切な .tflite 形式に変換します。発生する可能性のあるエラーの種類とその修正方法の詳細なガイドについては、AI Edge Torch Converter の README をご覧ください。

モデルは torch.export.export と互換性がある必要があります。つまり、TorchDynamo でエクスポート可能である必要があります。したがって、テンソル内のランタイム値に依存する Python の条件分岐を含めることはできません。torch.export.export の実行中に次のエラーが表示された場合、モデルは torch.export.export でエクスポートできません。また、モデルのテンソルに動的な入力ディメンションや出力ディメンションを含めることはできません。これにはバッチ ディメンションが含まれます。

TensorRT 互換または ONNX エクスポート可能な PyTorch モデルから始めることもできます。

  • モデルの TensorRT 互換バージョンは、一部のタイプの TensorRT 変換ではモデルが TorchDynamo エクスポート可能である必要があるため、良い出発点となります。モデルで NVIDIA / CUDA オペレーションを使用する場合は、標準の PyTorch オペレーションに置き換える必要があります。

  • ONNX エクスポート可能な PyTorch モデルは、良い出発点となります。ただし、一部の ONNX モデルでは、TorchDynamo ではなく TorchScript を使用してエクスポートします。この場合、モデルは TorchDynamo エクスポート可能ではない可能性があります(ただし、元のモデルコードよりも近い可能性はあります)。

詳細については、PyTorch モデルを LiteRT に変換するをご覧ください。

LiteRT パッケージを追加する

npm から @litertjs/core パッケージをインストールします。

npm install @litertjs/core

パッケージをインポートして、その Wasm ファイルを読み込みます。

import {loadLiteRt} from '@litertjs/core';

// They are located in node_modules/@litertjs/core/wasm/
// Serve them statically on your server.
await loadLiteRt(`your/path/to/wasm/`);

モデルを読み込む

LiteRT.js と LiteRT-TFJS 変換ユーティリティをインポートして初期化します。また、LiteRT.js にテンソルを渡すために TensorFlow.js をインポートする必要があります。

import {CompileOptions, loadAndCompile, loadLiteRt, setWebGpuDevice} from '@litertjs/core';
import {runWithTfjsTensors} from '@litertjs/tfjs-interop';

// TensorFlow.js imports
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu'; // Only WebGPU is supported
import {WebGPUBackend} from '@tensorflow/tfjs-backend-webgpu';

async function main() {
  // Initialize TensorFlow.js WebGPU backend
  await tf.setBackend('webgpu');

  // Initialize LiteRT.js's Wasm files
  await loadLiteRt('your/path/to/wasm/');

  // Make LiteRt use the same GPU device as TFJS (for tensor conversion)
  const backend = tf.backend() as WebGPUBackend;
  setWebGpuDevice(backend.device);
  // ...
}

main();

変換された LiteRT モデルを読み込みます。

const model = await loadAndCompile('path_to_model.tflite', {
  accelerator: 'webgpu', // or 'wasm'
});

モデル パイプラインを作成する

モデルをアプリに接続する前処理と後処理のロジックを記述します。前処理と後処理には TensorFlow.js を使用することをおすすめしますが、TensorFlow.js で記述されていない場合は、await tensor.data を呼び出して値を ArrayBuffer として取得するか、await tensor.array を呼び出して構造化された JS 配列を取得できます。

次に、ResNet18 のエンドツーエンド パイプラインの例を示します。

// Wrap in a tf.tidy call to automatically clean up intermediate TensorFlow.js tensors.
// (Note: tidy only supports synchronous functions).
const top5 = tf.tidy(() => {
  // Get RGB data values from an image element and convert it to range [0, 1).
  const image = tf.browser.fromPixels(dogs, 3).div(255);

  // These preprocessing steps come from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L315
  // The mean and standard deviation for the image normalization come from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_presets.py#L38
  const imageData = image.resizeBilinear([224, 224])
    .sub([0.485, 0.456, 0.406])
    .div([0.229, 0.224, 0.225])
    .reshape([1, 224, 224, 3])
    .transpose([0, 3, 1, 2]);

  // You can pass inputs as a single tensor, an array, or a JS Object
  // where keys are the tensor names in the TFLite model.
  // When passing an Object, the output is also an Object.
  // Here, we're passing a single tensor, so the output is an array.
  const probabilities = runWithTfjsTensors(model, imageData)[0];

  // Get the top five classes.
  return tf.topk(probabilities, 5);
});

const values = await top5.values.data();
const indices = await top5.indices.data();
top5.values.dispose(); // Clean up the tfjs tensors.
top5.indices.dispose();

// Print the top five classes.
const classes = ... // Class names are loaded from a JSON file in the demo.
for (let i = 0; i < 5; ++i) {
  const text = `${classes[indices[i]]}: ${values[i]}`;
  console.log(text);
}

テストとトラブルシューティング

アプリケーションをテストしてエラーを処理する方法については、次のセクションをご覧ください。

疑似入力によるテスト

モデルを読み込んだら、まず偽の入力でモデルをテストすることをおすすめします。これにより、モデル パイプラインの前処理ロジックと後処理ロジックの作成に時間を費やす前に、実行時エラーをキャッチできます。これを確認するには、LiteRT.js モデル テスターを使用するか、手動でテストします。

LiteRT.js モデル テスター

LiteRT.js モデル テスターは、ランダムな入力を使用して GPU と CPU でモデルを実行し、モデルが GPU で正しく実行されることを確認します。以下の点が確認されます。

  • 入力データ型と出力データ型がサポートされているかどうか。
  • すべてのオペレーションが GPU で利用可能かどうか。
  • GPU 出力が参照 CPU 出力とどの程度一致しているか。
  • GPU 推論のパフォーマンス。

LiteRT.js モデル テスターを実行するには、npm i @litertjs/model-tester を実行してから npx model-tester を実行します。モデルを実行するためのブラウザタブが開きます。

手動モデルテスト

LiteRT.js モデル テスター(@litertjs/model-tester)を使用する代わりに、モデルを手動でテストする場合は、偽の入力を生成して runWithTfjsTensors でモデルを実行できます。

フェイク入力を生成するには、入力テンソルの名前と形状を知る必要があります。これらは、LiteRT.js で model.getInputDetails または model.getOutputDetails を呼び出すことで見つけることができます。それらを見つける簡単な方法は、モデルの作成後にブレークポイントを設定することです。または、モデル エクスプローラを使用します。

入力と出力の形状と名前がわかったら、偽の入力でモデルをテストできます。これにより、残りの ML パイプラインを記述する前にモデルが実行されることが確実になります。これにより、すべてのモデル オペレーションがサポートされていることがテストされます。次に例を示します。

// Imports, initialization, and model loading...
// Create fake inputs for the model
const fakeInputs = model.getInputDetails().map(
    ({shape, dtype}) => tf.ones(shape, dtype));

// Run the model
const outputs = runWithTfjsTensors(model, fakeInputs);
console.log(outputs);

エラータイプ

一部の LiteRT モデルは LiteRT.js でサポートされていない場合があります。エラーは通常、次のカテゴリに分類されます。

  • Shape Mismatch: GPU にのみ影響する既知のバグ。
  • Operation Not Supported: ランタイムがモデル内のオペレーションをサポートしていません。WebGPU バックエンドは CPU よりもカバレッジが限られているため、GPU でこのエラーが表示される場合は、代わりに CPU でモデルを実行できる可能性があります。
  • サポートされていないテンソル型: LiteRT.js は、モデルの入出力で int32 テンソルと float32 テンソルのみをサポートします。
  • モデルが大きすぎる: LiteRT.js で読み込めるモデルのサイズには上限があります。

Operation Not Supported

これは、使用されているバックエンドがモデル内のオペレーションの 1 つをサポートしていないことを示します。このオペレーションを回避するために元の PyTorch モデルを書き換えて再変換するか、CPU でモデルを実行する必要があります。

BROADCAST_TO の場合、モデルへのすべての入力テンソルのバッチ ディメンションを同じにすることで解決できます。他のケースでは、より複雑になる可能性があります。

サポートされていないテンソル型

LiteRT.js は、モデルの入力と出力に int32 テンソルと float32 テンソルのみをサポートしています。

モデルが大きすぎる

通常、これは Aborted() の呼び出し、またはモデル読み込み時のメモリ割り当ての失敗として現れます。LiteRT.js で読み込めるモデルのサイズには上限があるため、このエラーが表示された場合は、モデルが大きすぎる可能性があります。ai-edge-quantizer を使用して重みを量子化できますが、計算は float32 または float16 のままにし、モデルの入力と出力は float32 または int32 のままにします。