适用于 Android 的 LLM 推断指南

借助 LLM Inference API,您可以完全在设备端为 Android 应用运行大语言模型 (LLM),并使用这些模型执行各种任务,例如生成文本、以自然语言形式检索信息以及汇总文档。该任务内置了对多个文本到文本大型语言模型的支持,因此您可以将最新的设备端生成式 AI 模型应用于 Android 应用。

此任务支持 Gemma 的以下变体:Gemma-2 2B、Gemma 2B 和 Gemma 7B。Gemma 是一系列先进的轻量级开放式模型,其开发采用了与 Gemini 模型相同的研究成果和技术。它还支持以下外部模型:Phi-2Falcon-RW-1BStableLM-3B

除了受支持的模型之外,用户还可以使用 Google 的 AI Edge Torch 将 PyTorch 模型导出为多签名 LiteRT (tflite) 模型,这些模型会与分词器参数捆绑在一起,以创建与 LLM 推理 API 兼容的任务软件包。

您可以通过 MediaPipe Studio 演示查看此任务的实际运作方式。如需详细了解此任务的功能、模型和配置选项,请参阅概览

代码示例

本指南将介绍一个 Android 基本文本生成应用示例。您可以将该应用用作您自己的 Android 应用的起点,也可以在修改现有应用时参考该应用。示例代码托管在 GitHub 上。

下载代码

以下说明介绍了如何使用 git 命令行工具创建示例代码的本地副本。

如需下载示例代码,请执行以下操作:

  1. 使用以下命令克隆 Git 代码库:
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. (可选)将您的 Git 实例配置为使用稀疏检出,以便您只保留 LLM Inference API 示例应用的文件:
    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/android
    

创建示例代码的本地版本后,您可以将项目导入 Android Studio 并运行应用。如需了解相关说明,请参阅 Android 设置指南

设置

本部分介绍了专门用于设置开发环境和代码项目以使用 LLM Inference API 的关键步骤。如需了解如何设置开发环境以使用 MediaPipe 任务(包括平台版本要求)的一般信息,请参阅 Android 设置指南

依赖项

LLM Inference API 使用 com.google.mediapipe:tasks-genai 库。将以下依赖项添加到 Android 应用的 build.gradle 文件中:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}

对于搭载 Android 12(API 31)或更高版本的设备,请添加原生 OpenCL 库依赖项。如需了解详情,请参阅有关 uses-native-library 标记的文档。

将以下 uses-native-library 标记添加到 AndroidManifest.xml 文件中:

<uses-native-library android:name="libOpenCL.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-car.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false"/>

型号

MediaPipe LLM Inference API 需要与此任务兼容的经过训练的文本到文本语言模型。下载模型后,请安装所需的依赖项,并将模型推送到 Android 设备。如果您使用的是 Gemma 以外的模型,则必须将该模型转换为与 MediaPipe 兼容的格式。

如需详细了解 LLM Inference API 适用的已训练模型,请参阅任务概览的“模型”部分

下载模型

在初始化 LLM Inference API 之前,请下载某个受支持的模型,并将该文件存储在项目目录中:

  • Gemma-2 2B:最新版本的 Gemma 系列模型。是一系列先进的轻量级开放模型的一部分,这些模型采用与 Gemini 模型相同的研究成果和技术构建而成。
  • Gemma 2B:是一系列先进的轻量级开放式模型的一部分,其开发采用了与 Gemini 模型相同的研究成果和技术。非常适合用于处理各种文本生成任务,包括问答、摘要和推理。
  • Phi-2:一个拥有 27 亿参数的 Transformer 模型,最适合问答、聊天和代码格式。
  • Falcon-RW-1B:一个参数数为 10 亿的仅解码器因果模型,基于 RefinedWeb 的 3500 亿个词元进行训练。
  • StableLM-3B:一个拥有 30 亿参数的 decoder-only 语言模型,基于多样化的英语和代码数据集内的 1 万亿个词元进行了预训练。

除了受支持的模型之外,您还可以使用 Google 的 AI Edge Torch 将 PyTorch 模型导出为多签名 LiteRT (tflite) 模型。如需了解详情,请参阅适用于 PyTorch 模型的 Torch 生成式转换器

我们建议使用 Gemma-2 2B,您可以在 Kaggle Models 上找到该模型。如需详细了解其他可用模型,请参阅任务概览的“模型”部分

将模型转换为 MediaPipe 格式

LLM Inference API 与两类模型兼容,其中一些模型需要转换。使用下表确定您的模型所需的步骤方法。

模型 转化方法 兼容的平台 文件类型
支持的模型 Gemma 2B、Gemma 7B、Gemma-2 2B、Phi-2、StableLM、Falcon MediaPipe Android、iOS、网站 .bin
其他 PyTorch 模型 所有 PyTorch LLM 模型 AI Edge Torch Generative 库 Android、iOS .task

我们在 Kaggle 上托管了 Gemma 2B、Gemma 7B 和 Gemma-2 2B 的转换后的 .bin 文件。这些模型可以使用我们的 LLM 推理 API 直接部署。如需了解如何转换其他模型,请参阅模型转换部分。

将模型推送到设备

output_path 文件夹的内容推送到 Android 设备。

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin

创建任务

MediaPipe LLM Inference API 使用 createFromOptions() 函数来设置任务。createFromOptions() 函数接受配置选项的值。如需详细了解配置选项,请参阅配置选项

以下代码使用基本配置选项初始化任务:

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPATH('/data/local/.../')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

配置选项

使用以下配置选项设置 Android 应用:

选项名称 说明 值范围 默认值
modelPath 模型在项目目录中的存储路径。 路径 不适用
maxTokens 模型处理的词元(输入词元 + 输出词元)数量上限。 整数 512
topK 模型在生成的每个步骤中考虑的令牌数。 将预测限制为前 k 个概率最高的 token。 整数 40
temperature 生成期间引入的随机性程度。温度越高,生成的文本越富有创造力,温度越低,生成的文本越具可预测性。 浮点数 0.8
randomSeed 文本生成期间使用的随机种子。 整数 0
loraPath 设备本地 LoRA 模型的绝对路径。注意:此功能仅适用于 GPU 型号。 路径 不适用
resultListener 设置结果监听器以异步接收结果。 仅在使用异步生成方法时适用。 不适用 不适用
errorListener 设置可选的错误监听器。 不适用 不适用

准备数据

LLM Inference API 接受以下输入:

  • prompt(字符串):问题或提示。
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

运行任务

使用 generateResponse() 方法针对上一部分 (inputPrompt) 中提供的输入文本生成文本响应。这会生成单个生成的响应。

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

如需流式传输响应,请使用 generateResponseAsync() 方法。

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

处理和显示结果

LLM 推理 API 会返回一个 LlmInferenceResult,其中包含生成的回答文本。

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

LoRA 模型自定义

Mediapipe LLM Inference API 可配置为支持大语言模型的低秩自适应 (LoRA)。利用经过微调的 LoRA 模型,开发者可以通过经济高效的训练流程自定义 LLM 的行为。

LLM Inference API 的 LoRA 支持适用于 GPU 后端的所有 Gemma 变体和 Phi-2 模型,LoRA 权重仅适用于注意力层。此初始实现将作为实验性 API 用于未来开发,我们计划在即将发布的更新中支持更多模型和各种类型的图层。

准备 LoRA 模型

按照 HuggingFace 上的说明,使用支持的模型类型(Gemma 或 Phi-2)在您自己的数据集上训练经过微调的 LoRA 模型。Gemma-2 2BGemma 2BPhi-2 模型均以 safetensors 格式在 HuggingFace 上提供。由于 LLM Inference API 仅支持注意力层上的 LoRA,因此在创建 LoraConfig 时,请仅指定注意力层,如下所示:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

如需进行测试,您可以使用 HuggingFace 上提供的适用于 LLM 推理 API 的公开可用的微调 LoRA 模型。例如,对于 Gemma-2B,使用 monsterapi/gemma-2b-lora-maths-orca-200k;对于 Phi-2,使用 lole25/phi-2-sft-ultrachat-lora

使用准备好的训练数据集进行训练并保存模型后,您会获得一个包含经过微调的 LoRA 模型权重的 adapter_model.safetensors 文件。safetensors 文件是模型转换中使用的 LoRA 检查点。

下一步,您需要使用 MediaPipe Python 软件包将模型权重转换为 TensorFlow Lite Flatbuffer。ConversionConfig 应指定基本模型选项以及其他 LoRA 选项。请注意,由于该 API 仅支持使用 GPU 进行 LoRA 推理,因此后端必须设置为 'gpu'

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

转换器将输出两个 TFLite FlatBuffer 文件,一个用于基准模型,另一个用于 LoRA 模型。

LoRA 模型推理

Web、Android 和 iOS LLM 推理 API 已更新为支持 LoRA 模型推理。

Android 在初始化期间支持静态 LoRA。如需加载 LoRA 模型,用户需要指定 LoRA 模型路径以及基础 LLM。

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

如需使用 LoRA 运行 LLM 推理,请使用与基准模型相同的 generateResponse()generateResponseAsync() 方法。