LiteRT.js는 프로덕션 Web 애플리케이션을 타겟팅하는 Google의 고성능 WebAI 런타임입니다. LiteRT 스택의 연속으로, 다중 프레임워크 지원을 보장하고 모든 플랫폼에서 핵심 런타임을 통합합니다.
LiteRT.js는 다음 핵심 기능을 지원합니다.
- 브라우저 내 하드웨어 가속 추론: 경량 WebAssembly (Wasm)에 매핑된 XNNPack으로 가속화된 뛰어난 CPU 성능으로 모델을 실행합니다. GPU 및 전용 하드웨어 확장 (예: NPU)의 경우 LiteRT.js는 WebGPU API와 새롭게 등장하는 WebNN API를 모두 기본적으로 표시하여 세부적인 플랫폼별 최적화를 지원합니다.
- 다중 프레임워크 호환성: 선호하는 ML 프레임워크(PyTorch, JAX, TensorFlow)에서 네이티브로 컴파일하여 개발 의미 체계를 간소화합니다.
- 기존 파이프라인 반복: 기본적으로 지원되는 TensorFlow.js 텐서를 직접 경계 입력 및 출력으로 파싱하여 기존 TensorFlow.js 아키텍처와 즉시 통합합니다.
설치
npm에서 @litertjs/core 패키지를 설치합니다.
npm install @litertjs/core
Wasm 파일은 node_modules/@litertjs/core/wasm/에 있습니다. 편의를 위해 전체 wasm/ 폴더를 복사하여 제공합니다. 그런 다음 패키지를 가져오고 Wasm 파일을 로드합니다.
import {loadLiteRt} from '@litertjs/core';
// Load the LiteRT.js Wasm files from a CDN.
await loadLiteRt('https://cdn.jsdelivr.net/npm/@litertjs/core/wasm/')
// Alternatively, host them from your server.
// They are located in node_modules/@litertjs/core/wasm/
await loadLiteRt(`your/path/to/wasm/`);
모델 변환
LiteRT.js는 나머지 LiteRT 생태계와 동일한 .tflite 형식을 사용하며 Kaggle 및 Huggingface의 기존 모델을 지원합니다. 새 PyTorch 모델이 있는 경우 변환해야 합니다.
PyTorch 모델을 LiteRT로 변환
PyTorch 모델을 LiteRT로 변환하려면 litert-torch 변환기를 사용하세요.
import litert_torch
# Load your torch model. We're using resnet for this example.
resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
sample_inputs = (torch.randn(1, 3, 224, 224),)
# Convert the model to LiteRT.
edge_model = litert_torch.convert(resnet18.eval(), sample_inputs)
# Export the model.
edge_model.export('resnet.tflite')
변환된 모델 실행
모델을 .tflite 파일로 변환한 후 브라우저에서 실행할 수 있습니다.
import {loadAndCompile} from '@litertjs/core';
// Load the model hosted from your server. This makes an http(s) request.
const model = await loadAndCompile('/path/to/model.tflite', {
accelerator: 'webgpu',
// Can select from 'webnn', 'webgpu', & 'wasm'.
// Additionally, you can pass an array of accelerators e.g. ['webnn', 'wasm']
// if you would like to fallback to CPU execution,
// Note that ONLY cpu fallback is supported for now
// (i.e. specifying ['webnn', 'webgpu']) will lead to compilation errors
});
// The model can also be loaded from a Uint8Array if you want to fetch it yourself.
// Create image input data
const image = new Float32Array(224 * 224 * 3).fill(0);
const inputTensor = new Tensor(image, /* shape */ [1, 3, 224, 224]);
// Run the model
const outputs = await model.run(inputTensor);
// You can also use `await model.run([inputTensor]);`
// or `await model.run({'input_tensor_name': inputTensor});`
// Clean up and get outputs
inputTensor.delete();
const output = outputs[0];
const outputData = await output.data();
output.delete();
기존 TensorFlow.js 파이프라인에 통합
다음과 같은 이유로 TensorFlow.js 파이프라인에 LiteRT.js를 통합하는 것이 좋습니다.
- 뛰어난 GPU 및 하드웨어 성능: LiteRT.js 모델은 WebGPU 가속을 활용하여 브라우저 아키텍처 전반에서 성능을 최적화합니다. LiteRT.js는 WebGPU 및 곧 출시될 WebNN을 지원하여 다양한 에지 기기에서 유연한 하드웨어 가속을 제공합니다.
- 더 쉬운 모델 변환 경로: LiteRT.js 변환 경로는 PyTorch에서 LiteRT로 직접 이동합니다. PyTorch에서 TensorFlow.js로의 변환 경로는 훨씬 더 복잡하며 PyTorch -> ONNX -> TensorFlow -> TensorFlow.js로 이동해야 합니다.
- 디버깅 도구: LiteRT.js 변환 경로에는 디버깅 도구가 함께 제공됩니다.
LiteRT.js는 TensorFlow.js 파이프라인 내에서 작동하도록 설계되었으며 TensorFlow.js 사전 처리 및 사후 처리와 호환되므로 모델 자체만 이전하면 됩니다.
다음 단계에 따라 LiteRT.js를 TensorFlow.js 파이프라인에 통합하세요.
- 원래 TensorFlow, JAX 또는 PyTorch 모델을
.tflite로 변환합니다. 자세한 내용은 모델 변환 섹션을 참고하세요. @litertjs/core및@litertjs/tfjs-interopNPM 패키지를 설치합니다.- TensorFlow.js WebGPU 백엔드를 가져와 사용합니다. LiteRT.js가 TensorFlow.js와 상호 운용하려면 이 작업이 필요합니다.
- TensorFlow.js 모델 로드를 LiteRT.js 모델 로드로 바꿉니다.
- TensorFlow.js
model.predict(inputs) 또는model.execute(inputs)를runWithTfjsTensors(liteRtModel, inputs)로 대체합니다.runWithTfjsTensors는 TensorFlow.js 모델에서 사용하는 것과 동일한 입력 텐서를 사용하고 TensorFlow.js 텐서를 출력합니다. - 모델 파이프라인이 예상한 결과를 출력하는지 테스트합니다.
runWithTfjsTensors와 함께 LiteRT.js를 사용하는 경우 모델 입력에 다음과 같은 변경사항이 필요할 수도 있습니다.
- 입력 재정렬: 변환기가 모델의 입력과 출력을 정렬한 방식에 따라 입력과 출력을 전달할 때 순서를 변경해야 할 수 있습니다.
- 입력 전치: 변환기가 TensorFlow.js에서 사용하는 것과 비교하여 모델의 입력 및 출력 레이아웃을 변경했을 수도 있습니다. 모델과 일치하도록 입력을 전치하고 파이프라인의 나머지 부분과 일치하도록 출력을 전치해야 할 수 있습니다.
- 입력 이름 바꾸기: 이름이 지정된 입력을 사용하는 경우 이름도 변경되었을 수 있습니다.
model.getInputDetails() 및 model.getOutputDetails()을 사용하여 모델의 입력 및 출력에 관한 자세한 정보를 확인할 수 있습니다.