适用于 Web 的 LLM 推理指南

借助 LLM Inference API,您可以完全在设备端为 Web 应用运行大语言模型 (LLM),这些模型可用于执行各种任务,例如生成文本、以自然语言形式检索信息以及总结文档。该任务内置对多个文本到文本大型语言模型的支持,因此您可以将最新的设备端生成式 AI 模型应用于 Web 应用。

如需快速将 LLM Inference API 添加到您的 Web 应用,请按照快速入门中的说明操作。如需查看运行 LLM Inference API 的 Web 应用的基本示例,请参阅示例应用。如需更深入地了解 LLM Inference API 的运作方式,请参阅配置选项模型转换LoRA 调优部分。

您可以通过 MediaPipe Studio 演示查看此任务的实际运作方式。如需详细了解此任务的功能、模型和配置选项,请参阅概览

快速入门

请按以下步骤将 LLM Inference API 添加到您的 Web 应用。LLM Inference API 需要与 WebGPU 兼容的网络浏览器。如需查看兼容的浏览器的完整列表,请参阅 GPU 浏览器兼容性

添加依赖项

LLM Inference API 使用 @mediapipe/tasks-genai 软件包。

安装本地暂存区所需的软件包:

npm install @mediapipe/tasks-genai

如需部署到服务器,请使用 jsDelivr 等内容分发网络 (CDN) 服务直接将代码添加到 HTML 页面中:

<head>
  <script src="https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/genai_bundle.cjs"
    crossorigin="anonymous"></script>
</head>

下载模型

Kaggle 模型下载采用 8 位量化格式的 Gemma-2 2B。如需详细了解可用的模型,请参阅“模型”文档

将模型存储在项目目录中:

<dev-project-root>/assets/gemma-2b-it-gpu-int8.bin

使用 baseOptions 对象 modelAssetPath 形参指定模型的路径:

baseOptions: { modelAssetPath: `/assets/gemma-2b-it-gpu-int8.bin`}

初始化任务

使用基本配置选项初始化任务:

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
llmInference = await LlmInference.createFromOptions(genai, {
    baseOptions: {
        modelAssetPath: '/assets/gemma-2b-it-gpu-int8.bin'
    },
    maxTokens: 1000,
    topK: 40,
    temperature: 0.8,
    randomSeed: 101
});

运行任务

使用 generateResponse() 函数触发推理。

const response = await llmInference.generateResponse(inputPrompt);
document.getElementById('output').textContent = response;

如需流式传输响应,请使用以下代码:

llmInference.generateResponse(
  inputPrompt,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});

示例应用

此示例应用是一个使用 LLM Inference API 的 Web 基本文本生成应用。您可以将该应用用作您自己的 Web 应用的起点,也可以在修改现有应用时参考该应用。示例代码托管在 GitHub 上。

使用以下命令克隆 Git 代码库:

git clone https://github.com/google-ai-edge/mediapipe-samples

如需了解详情,请参阅适用于网站的设置指南

配置选项

使用以下配置选项设置 Web 应用:

选项名称 说明 值范围 默认值
modelPath 模型在项目目录中的存储路径。 路径 不适用
maxTokens 模型处理的词元(输入词元 + 输出词元)数量上限。 整数 512
topK 模型在生成过程中每个步骤考虑的令牌数。 将预测限制为前 k 个概率最高的 token。 整数 40
temperature 生成过程中引入的随机性程度。温度越高,生成的文本就越富有创造力;温度越低,生成的文本就越具可预测性。 浮点数 0.8
randomSeed 文本生成期间使用的随机种子。 整数 0
loraRanks LoRA 模型在运行时要使用的 LoRA 排名。注意:此功能仅适用于 GPU 型号。 整数数组 不适用

模型转换

LLM Inference API 与以下类型的模型兼容,其中一些模型需要进行模型转换。使用下表确定您的模型所需的步骤方法。

模型 转化方法 兼容的平台 文件类型
Gemma-3 1B 无需转换 Android、网络 .task
Gemma 2B、Gemma 7B、Gemma-2 2B 无需转换 Android、iOS、网页 .bin
Phi-2、StableLM、Falcon MediaPipe 转换脚本 Android、iOS、网页 .bin
所有 PyTorch LLM 模型 AI Edge Torch Generative 库 Android、iOS .task

如需了解如何转换其他模型,请参阅模型转换部分。

LoRA 自定义

LLM Inference API 支持使用 PEFT(参数高效微调)库进行 LoRA(低秩自适应)调优。LoRA 调优通过经济高效的训练流程自定义 LLM 的行为,根据新训练数据创建一小组可训练权重,而不是重新训练整个模型。

LLM Inference API 支持向 Gemma-2 2BGemma 2BPhi-2 模型的注意力层添加 LoRA 权重。下载 safetensors 格式的模型。

基本模型必须采用 safetensors 格式,才能创建 LoRA 权重。完成 LoRA 训练后,您可以将模型转换为 FlatBuffers 格式,以便在 MediaPipe 上运行。

准备 LoRA 权重

使用 PEFT 中的 LoRA 方法指南,基于您自己的数据集训练经过微调的 LoRA 模型。

LLM Inference API 仅支持在注意力层上使用 LoRA,因此请仅在 LoraConfig 中指定注意力层:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

使用准备好的数据集进行训练并保存模型后,adapter_model.safetensors 中会提供经过微调的 LoRA 模型权重。safetensors 文件是模型转换期间使用的 LoRA 检查点。

模型转换

使用 MediaPipe Python 软件包将模型权重转换为 Flatbuffer 格式。ConversionConfig 指定了基本模型选项以及其他 LoRA 选项。

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_FILE,
)

converter.convert_checkpoint(config)

转换器将生成两个与 MediaPipe 兼容的文件,一个用于基准模型,另一个用于 LoRA 模型。

LoRA 模型推理

Web 在运行时支持动态 LoRA,这意味着用户在初始化期间声明 LoRA 秩。这意味着,您可以在运行时切换不同的 LoRA 模型。

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
const llmInference = await LlmInference.createFromOptions(genai, {
    // options for the base model
    ...
    // LoRA ranks to be used by the LoRA models during runtime
    loraRanks: [4, 8, 16]
});

在初始化基础模型后,在运行时加载 LoRA 模型。在生成 LLM 回答时,通过传递模型引用触发 LoRA 模型。

// Load several LoRA models. The returned LoRA model reference is used to specify
// which LoRA model to be used for inference.
loraModelRank4 = await llmInference.loadLoraModel(loraModelRank4Url);
loraModelRank8 = await llmInference.loadLoraModel(loraModelRank8Url);

// Specify LoRA model to be used during inference
llmInference.generateResponse(
  inputPrompt,
  loraModelRank4,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});