AI Edge RAG SDK 提供了基本组件,可用于通过 LLM 推理 API 构建检索增强生成 (RAG) 流水线。RAG 流水线可让 LLM 访问用户提供的数据,这些数据可能包含最新信息、敏感信息或特定领域的信息。借助 RAG 提供的额外信息检索功能,LLM 可以针对特定使用情形生成更准确、更贴合上下文的回答。
本指南将引导您使用 AI Edge RAG SDK 通过 LLM 推理 API 实现示例应用的基本功能。本指南重点介绍如何构建 RAG 流水线。如需详细了解如何使用 LLM 推理 API,请参阅 Android 版 LLM 推理指南。
您可以在 GitHub 上找到完整的示例应用。首先,构建应用,通读用户提供的数据 (sample_context.txt
),然后向 LLM 询问与文本文件中的信息相关的问题。
运行示例应用
本指南引用了一个示例,该示例展示了如何为 Android 构建具有 RAG 功能的基本文本生成应用。您可以将该示例应用作为自己的 Android 应用的起点,也可以在修改现有应用时参考该示例应用。
此应用针对 Pixel 8、Pixel 9、S23 和 S24 等高端设备进行了优化。将 Android 设备连接到工作站,并确保您使用的是最新版本的 Android Studio。如需了解详情,请参阅 Android 设置指南。
下载应用代码
以下说明展示了如何使用 git 命令行工具创建示例代码的本地副本。
使用以下命令克隆 Git 代码库:
git clone https://github.com/google-ai-edge/ai-edge-apis
在创建示例代码的本地版本后,您可以将项目导入 Android Studio 并运行应用。
下载模型
示例应用已配置为使用 Gemma-3 1B。Gemma-3 1B 是一系列先进的轻量级开放模型 Gemma 的一员,基于用于创建 Gemini 模型的研究和技术构建而成。该模型包含 10 亿个参数和开放权重。
从 Hugging Face 下载 Gemma-3 1B 后,将模型推送到您的设备:
cd ~/Downloads
tar -xvzf gemma3-1b-it-int4.tar.gz
$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.task
您还可以将其他模型与示例应用搭配使用,但可能需要执行额外的配置步骤。
设置嵌入器
嵌入器会从用户提供的数据中提取文本块,并将其转换为可捕捉语义含义的矢量化数值表示形式。LLM 会参考这些嵌入来识别相关向量,并将语义相关性最高的块纳入生成的输出中。
此示例应用旨在与两个嵌入器(Gemini 嵌入器和 Gecko 嵌入器)搭配使用。
使用 Gecko 嵌入器进行设置
默认情况下,示例应用配置为使用 Gecko 嵌入器 (GeckoEmbeddingModel
),并在设备上完全运行模型。
Gecko 嵌入器提供浮点模型和量化模型,并有多个版本,可支持不同的序列长度。如需了解详情,请参阅 Gecko 模型卡片。
您可以在模型文件名中找到模型规范。例如:
Gecko_256_f32.tflite
:支持最多 256 个 token 序列的浮点模型。Gecko_1024_quant.tflite
:支持最多 1024 个 token 序列的量化模型。
序列长度是指模型可以嵌入的最大块大小。例如,如果传递给 Gecko_256_f32.tflite
模型的块超过了序列长度,模型将嵌入前 256 个 token,并截断块的其余部分。
将词元化器模型 (sentencepiece.model
) 和 Gecko 嵌入器推送到设备:
adb push sentencepiece.model /data/local/tmp/sentencepiece.model
adb push Gecko_256_f32.tflite /data/local/tmp/gecko.tflite
嵌入模型与 CPU 和 GPU 都兼容。默认情况下,示例应用配置为使用 GPU 上的 Gecko 模型提取嵌入内容。
companion object {
...
private const val USE_GPU_FOR_EMBEDDINGS = true
}
使用 Gemini Embedder 进行设置
Gemini Embedder (GeminiEmbedder
) 使用 Gemini Cloud API 创建嵌入。这需要使用 Google Gemini API 密钥才能运行应用,您可以从 Google Gemini API 设置页面获取该密钥。
在 Google AI Studio 中获取 Gemini API 密钥
在 RagPipeline.kt 中添加 Gemini API 密钥并将 COMPUTE_EMBEDDINGS_LOCALLY
设置为 false:
companion object {
...
private const val COMPUTE_EMBEDDINGS_LOCALLY = false
private const val GEMINI_API_KEY = "<API_KEY>"
}
运作方式
本部分更深入地介绍了应用的 RAG 流水线组件。您可以在 RagPipeline.kt 中查看大部分代码。
依赖项
RAG SDK 使用 com.google.ai.edge.localagents:localagents-rag
库。将此依赖项添加到 Android 应用的 build.gradle
文件中:
dependencies {
...
implementation("com.google.ai.edge.localagents:localagents-rag:0.1.0")
implementation("com.google.mediapipe:tasks-genai:0.10.22")
}
用户提供的数据
应用中用户提供的数据是一个名为 sample_context.txt
的文本文件,存储在 assets
目录中。该应用会获取文本文件的各个块,创建这些块的嵌入,并在生成输出文本时参考这些嵌入。
以下代码段可在 MainActivity.kt 中找到:
class MainActivity : ComponentActivity() {
lateinit var chatViewModel: ChatViewModel
...
chatViewModel.memorizeChunks("sample_context.txt")
...
}
分块
为简单起见,sample_context.txt
文件包含示例应用用于创建块的 <chunk_splitter>
标记。然后,为每个分块创建嵌入。在生产应用中,块的大小是一项关键考虑因素。如果块过大,矢量就无法包含足够的特异性,从而无法发挥作用;如果块过小,矢量就无法包含足够的上下文。
示例应用通过 RagPipeline.kt 中的 memorizeChunks
函数处理分块。
嵌入
该应用提供两种文本嵌入途径:
- Gecko 嵌入器:使用 Gecko 模型在本地(设备上)提取文本嵌入。
- Gemini Embedder:使用 Generative Language Cloud API 进行基于云的文本嵌入提取。
示例应用会根据用户是打算在本地还是通过 Google Cloud 计算嵌入内容来选择嵌入器。以下代码段可在 RagPipeline.kt 中找到:
private val embedder: Embedder<String> = if (COMPUTE_EMBEDDINGS_LOCALLY) {
GeckoEmbeddingModel(
GECKO_MODEL_PATH,
Optional.of(TOKENIZER_MODEL_PATH),
USE_GPU_FOR_EMBEDDINGS,
)
} else {
GeminiEmbedder(
GEMINI_EMBEDDING_MODEL,
GEMINI_API_KEY
)
}
数据库
此示例应用使用 SQLite (SqliteVectorStore
) 来存储文本嵌入内容。您还可以使用 DefaultVectorStore
数据库进行非持久性向量存储。
以下代码段可在 RagPipeline.kt 中找到:
private val config = ChainConfig.create(
mediaPipeLanguageModel, PromptBuilder(QA_PROMPT_TEMPLATE1),
DefaultSemanticTextMemory(
SqliteVectorStore(768), embedder
)
)
示例应用将嵌入维度设置为 768,这指的是向量数据库中每个向量的长度。
Chain
RAG SDK 提供链,可将多个 RAG 组件组合成单个流水线。您可以使用链来编排检索和查询模型。该 API 基于 Chain 接口。
示例应用使用检索和推理链。 以下代码段可在 RagPipeline.kt 中找到:
private val retrievalAndInferenceChain = RetrievalAndInferenceChain(config)
当模型生成回答时,系统会调用链:
suspend fun generateResponse(
prompt: String,
callback: AsyncProgressListener<LanguageModelResponse>?
): String =
coroutineScope {
val retrievalRequest =
RetrievalRequest.create(
prompt,
RetrievalConfig.create(2, 0.0f, TaskType.QUESTION_ANSWERING)
)
retrievalAndInferenceChain.invoke(retrievalRequest, callback).await().text
}