借助 LLM Inference API,您可以完全在设备端为 Android 应用运行大语言模型 (LLM),并使用这些模型执行各种任务,例如生成文本、以自然语言形式检索信息以及汇总文档。该任务内置了对多个文本到文本大型语言模型的支持,因此您可以将最新的设备端生成式 AI 模型应用于 Android 应用。
此任务支持 Gemma 的以下变体:Gemma-2 2B、Gemma 2B 和 Gemma 7B。Gemma 是一系列先进的轻量级开放式模型,其开发采用了与 Gemini 模型相同的研究成果和技术。它还支持以下外部模型:Phi-2、Falcon-RW-1B 和 StableLM-3B。
除了受支持的模型之外,用户还可以使用 Google 的 AI Edge Torch 将 PyTorch 模型导出为多签名 LiteRT (tflite
) 模型,这些模型会与分词器参数捆绑在一起,以创建与 LLM 推理 API 兼容的任务软件包。
您可以通过 MediaPipe Studio 演示查看此任务的实际运作方式。如需详细了解此任务的功能、模型和配置选项,请参阅概览。
代码示例
本指南将介绍一个 Android 基本文本生成应用示例。您可以将该应用用作您自己的 Android 应用的起点,也可以在修改现有应用时参考该应用。示例代码托管在 GitHub 上。
下载代码
以下说明介绍了如何使用 git 命令行工具创建示例代码的本地副本。
如需下载示例代码,请执行以下操作:
- 使用以下命令克隆 Git 代码库:
git clone https://github.com/google-ai-edge/mediapipe-samples
- (可选)将您的 Git 实例配置为使用稀疏检出,以便您只保留 LLM Inference API 示例应用的文件:
cd mediapipe git sparse-checkout init --cone git sparse-checkout set examples/llm_inference/android
创建示例代码的本地版本后,您可以将项目导入 Android Studio 并运行应用。如需了解相关说明,请参阅 Android 设置指南。
设置
本部分介绍了专门用于设置开发环境和代码项目以使用 LLM Inference API 的关键步骤。如需了解如何设置开发环境以使用 MediaPipe 任务(包括平台版本要求)的一般信息,请参阅 Android 设置指南。
依赖项
LLM Inference API 使用 com.google.mediapipe:tasks-genai
库。将以下依赖项添加到 Android 应用的 build.gradle
文件中:
dependencies {
implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}
对于搭载 Android 12(API 31)或更高版本的设备,请添加原生 OpenCL 库依赖项。如需了解详情,请参阅有关 uses-native-library
标记的文档。
将以下 uses-native-library
标记添加到 AndroidManifest.xml
文件中:
<uses-native-library android:name="libOpenCL.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-car.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false"/>
型号
MediaPipe LLM Inference API 需要与此任务兼容的经过训练的文本到文本语言模型。下载模型后,请安装所需的依赖项,并将模型推送到 Android 设备。如果您使用的是 Gemma 以外的模型,则必须将该模型转换为与 MediaPipe 兼容的格式。
如需详细了解 LLM Inference API 适用的已训练模型,请参阅任务概览的“模型”部分。
下载模型
在初始化 LLM Inference API 之前,请下载某个受支持的模型,并将该文件存储在项目目录中:
- Gemma-2 2B:最新版本的 Gemma 系列模型。是一系列先进的轻量级开放模型的一部分,这些模型采用与 Gemini 模型相同的研究成果和技术构建而成。
- Gemma 2B:是一系列先进的轻量级开放式模型的一部分,其开发采用了与 Gemini 模型相同的研究成果和技术。非常适合用于处理各种文本生成任务,包括问答、摘要和推理。
- Phi-2:一个拥有 27 亿参数的 Transformer 模型,最适合问答、聊天和代码格式。
- Falcon-RW-1B:一个参数数为 10 亿的仅解码器因果模型,基于 RefinedWeb 的 3500 亿个词元进行训练。
- StableLM-3B:一个拥有 30 亿参数的 decoder-only 语言模型,基于多样化的英语和代码数据集内的 1 万亿个词元进行了预训练。
除了受支持的模型之外,您还可以使用 Google 的 AI Edge Torch 将 PyTorch 模型导出为多签名 LiteRT (tflite
) 模型。如需了解详情,请参阅适用于 PyTorch 模型的 Torch 生成式转换器。
我们建议使用 Gemma-2 2B,您可以在 Kaggle Models 上找到该模型。如需详细了解其他可用模型,请参阅任务概览的“模型”部分。
将模型转换为 MediaPipe 格式
LLM Inference API 与两类模型兼容,其中一些模型需要转换。使用下表确定您的模型所需的步骤方法。
模型 | 转化方法 | 兼容的平台 | 文件类型 | |
---|---|---|---|---|
支持的模型 | Gemma 2B、Gemma 7B、Gemma-2 2B、Phi-2、StableLM、Falcon | MediaPipe | Android、iOS、网站 | .bin |
其他 PyTorch 模型 | 所有 PyTorch LLM 模型 | AI Edge Torch Generative 库 | Android、iOS | .task |
我们在 Kaggle 上托管了 Gemma 2B、Gemma 7B 和 Gemma-2 2B 的转换后的 .bin
文件。这些模型可以使用我们的 LLM 推理 API 直接部署。如需了解如何转换其他模型,请参阅模型转换部分。
将模型推送到设备
将 output_path 文件夹的内容推送到 Android 设备。
$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin
创建任务
MediaPipe LLM Inference API 使用 createFromOptions()
函数来设置任务。createFromOptions()
函数接受配置选项的值。如需详细了解配置选项,请参阅配置选项。
以下代码使用基本配置选项初始化任务:
// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
.setModelPATH('/data/local/.../')
.setMaxTokens(1000)
.setTopK(40)
.setTemperature(0.8)
.setRandomSeed(101)
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)
配置选项
使用以下配置选项设置 Android 应用:
选项名称 | 说明 | 值范围 | 默认值 |
---|---|---|---|
modelPath |
模型在项目目录中的存储路径。 | 路径 | 不适用 |
maxTokens |
模型处理的词元(输入词元 + 输出词元)数量上限。 | 整数 | 512 |
topK |
模型在生成的每个步骤中考虑的令牌数。 将预测限制为前 k 个概率最高的 token。 | 整数 | 40 |
temperature |
生成期间引入的随机性程度。温度越高,生成的文本越富有创造力,温度越低,生成的文本越具可预测性。 | 浮点数 | 0.8 |
randomSeed |
文本生成期间使用的随机种子。 | 整数 | 0 |
loraPath |
设备本地 LoRA 模型的绝对路径。注意:此功能仅适用于 GPU 型号。 | 路径 | 不适用 |
resultListener |
设置结果监听器以异步接收结果。 仅在使用异步生成方法时适用。 | 不适用 | 不适用 |
errorListener |
设置可选的错误监听器。 | 不适用 | 不适用 |
准备数据
LLM Inference API 接受以下输入:
- prompt(字符串):问题或提示。
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."
运行任务
使用 generateResponse()
方法针对上一部分 (inputPrompt
) 中提供的输入文本生成文本响应。这会生成单个生成的响应。
val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")
如需流式传输响应,请使用 generateResponseAsync()
方法。
val options = LlmInference.LlmInferenceOptions.builder()
...
.setResultListener { partialResult, done ->
logger.atInfo().log("partial result: $partialResult")
}
.build()
llmInference.generateResponseAsync(inputPrompt)
处理和显示结果
LLM 推理 API 会返回一个 LlmInferenceResult
,其中包含生成的回答文本。
Here's a draft you can use:
Subject: Lunch on Saturday Reminder
Hi Brett,
Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.
Looking forward to it!
Best,
[Your Name]
LoRA 模型自定义
Mediapipe LLM Inference API 可配置为支持大语言模型的低秩自适应 (LoRA)。利用经过微调的 LoRA 模型,开发者可以通过经济高效的训练流程自定义 LLM 的行为。
LLM Inference API 的 LoRA 支持适用于 GPU 后端的所有 Gemma 变体和 Phi-2 模型,LoRA 权重仅适用于注意力层。此初始实现将作为实验性 API 用于未来开发,我们计划在即将发布的更新中支持更多模型和各种类型的图层。
准备 LoRA 模型
按照 HuggingFace 上的说明,使用支持的模型类型(Gemma 或 Phi-2)在您自己的数据集上训练经过微调的 LoRA 模型。Gemma-2 2B、Gemma 2B 和 Phi-2 模型均以 safetensors 格式在 HuggingFace 上提供。由于 LLM Inference API 仅支持注意力层上的 LoRA,因此在创建 LoraConfig
时,请仅指定注意力层,如下所示:
# For Gemma
from peft import LoraConfig
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
# For Phi-2
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)
如需进行测试,您可以使用 HuggingFace 上提供的适用于 LLM 推理 API 的公开可用的微调 LoRA 模型。例如,对于 Gemma-2B,使用 monsterapi/gemma-2b-lora-maths-orca-200k;对于 Phi-2,使用 lole25/phi-2-sft-ultrachat-lora。
使用准备好的训练数据集进行训练并保存模型后,您会获得一个包含经过微调的 LoRA 模型权重的 adapter_model.safetensors
文件。safetensors 文件是模型转换中使用的 LoRA 检查点。
下一步,您需要使用 MediaPipe Python 软件包将模型权重转换为 TensorFlow Lite Flatbuffer。ConversionConfig
应指定基本模型选项以及其他 LoRA 选项。请注意,由于该 API 仅支持使用 GPU 进行 LoRA 推理,因此后端必须设置为 'gpu'
。
import mediapipe as mp
from mediapipe.tasks.python.genai import converter
config = converter.ConversionConfig(
# Other params related to base model
...
# Must use gpu backend for LoRA conversion
backend='gpu',
# LoRA related params
lora_ckpt=LORA_CKPT,
lora_rank=LORA_RANK,
lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)
converter.convert_checkpoint(config)
转换器将输出两个 TFLite FlatBuffer 文件,一个用于基准模型,另一个用于 LoRA 模型。
LoRA 模型推理
Web、Android 和 iOS LLM 推理 API 已更新为支持 LoRA 模型推理。
Android 在初始化期间支持静态 LoRA。如需加载 LoRA 模型,用户需要指定 LoRA 模型路径以及基础 LLM。// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
.setModelPath('<path to base model>')
.setMaxTokens(1000)
.setTopK(40)
.setTemperature(0.8)
.setRandomSeed(101)
.setLoraPath('<path to LoRA model>')
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)
如需使用 LoRA 运行 LLM 推理,请使用与基准模型相同的 generateResponse()
或 generateResponseAsync()
方法。