AI Edge RAG SDK 提供了基本组件,可用于使用 LLM 推理 API 构建检索增强生成 (RAG) 流水线。RAG 流水线可让 LLM 访问用户提供的数据,其中可能包括更新后的数据、敏感数据或特定于领域的数据。借助 RAG 提供的额外信息检索功能,LLM 可以针对特定用例生成更准确、更具上下文感知的回答。
本指南将引导您使用 LLM Inference API 和 AI Edge RAG SDK 实现示例应用的基本实现。本指南重点介绍如何构建 RAG 流水线。如需详细了解如何使用 LLM 推理 API,请参阅 适用于 Android 的 LLM 推理指南。
您可以在 GitHub 上找到完整的示例应用。首先,构建应用,读取用户提供的数据 (sample_context.txt
),然后向 LLM 询问与文本文件中的信息相关的问题。
运行示例应用
本指南将介绍一个 Android 版基本文本生成应用示例,该示例使用了 RAG。您可以使用该示例应用作为自己的 Android 应用的基础,也可以在修改现有应用时参考该应用。
该应用针对 Pixel 8、Pixel 9、S23 和 S24 等高端设备进行了优化。将 Android 设备连接到工作站,并确保您使用的是当前版本的 Android Studio。如需了解详情,请参阅 Android 设置指南。
下载应用代码
以下说明介绍了如何使用 git 命令行工具创建示例代码的本地副本。
使用以下命令克隆 Git 代码库:
git clone https://github.com/google-ai-edge/ai-edge-apis
创建示例代码的本地版本后,您可以将项目导入 Android Studio 并运行应用。
下载模型
示例应用已配置为使用 Gemma-3 1B。Gemma-3 1B 属于 Gemma 系列,该系列由一系列先进的轻量级开放模型组成,这些模型采用了与 Gemini 模型相同的研究成果和技术构建而成。该模型包含 10 亿个参数和开放权重。
从 Hugging Face 下载 Gemma-3 1B 后,将模型推送到您的设备:
cd ~/Downloads
tar -xvzf gemma3-1b-it-int4.tar.gz
$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version .task
您还可以将其他模型与示例应用搭配使用,但可能需要执行额外的配置步骤。
设置嵌入程序
嵌入器会从用户提供的数据中提取文本块,并将其转换为矢量化数值表示法,以捕获其语义含义。LLM 会引用这些嵌入来识别相关的矢量,并在生成的输出中纳入语义上最相关的代码段。
此示例应用旨在与两个嵌入程序(Gemini 嵌入程序和 Gecko 嵌入程序)搭配使用。
使用 Gecko 嵌入器进行设置
默认情况下,示例应用配置为使用 Gecko 嵌入器 (GeckoEmbeddingModel
),并完全在设备端运行模型。
Gecko 嵌入器可作为浮点和量化模型提供,并针对不同的序列长度提供多个版本。如需了解详情,请参阅 Gecko 模型卡片。
您可以在模型文件名中找到模型规范。例如:
Gecko_256_fp32.tflite
:支持最多 256 个令牌的序列的浮点模型。Gecko_1024_quant.tflite
:量化模型,支持最多 1024 个 token 的序列。
序列长度是模型可以嵌入的最大分块大小。例如,如果向 Gecko_256_fp32.tflite
模型传递的块超出了序列长度,该模型将嵌入前 256 个令牌,并截断该块的其余部分。
将词解析器模型 (sentencepiece.model
) 和 Gecko 嵌入器推送到您的设备:
adb push sentencepiece.model /data/local/tmp/sentencepiece.model
adb push Gecko_256_fp32.tflite /data/local/tmp/gecko.tflite
嵌入模型与 CPU 和 GPU 兼容。默认情况下,示例应用配置为在 GPU 上使用 Gecko 模型提取嵌入。
companion object {
...
private const val USE_GPU_FOR_EMBEDDINGS = true
}
使用 Gemini 嵌入程序进行设置
Gemini 嵌入器 (GeminiEmbedder
) 使用 Gemini Cloud API 创建嵌入。这需要使用 Google Gemini API 密钥才能运行应用,您可以从 Google Gemini API 设置页面获取该密钥。
在 Google AI Studio 中获取 Gemini API 密钥
在 RagPipeline.kt 中添加您的 Gemini API 密钥并将 COMPUTE_EMBEDDINGS_LOCALLY
设为 false:
companion object {
...
private const val COMPUTE_EMBEDDINGS_LOCALLY = false
private const val GEMINI_API_KEY = "<API_KEY>"
}
运作方式
本部分详细介绍了应用的 RAG 流水线组件。您可以在 RagPipeline.kt 中查看大部分代码。
依赖项
RAG SDK 使用 com.google.ai.edge.localagents:localagents-rag
库。将以下依赖项添加到 Android 应用的 build.gradle
文件中:
dependencies {
...
implementation("com.google.ai.edge.localagents:localagents-rag:0.1.0")
implementation("com.google.mediapipe:tasks-genai:0.10.22")
}
用户提供的数据
应用中用户提供的数据是名为 sample_context.txt
的文本文件,该文件存储在 assets
目录中。该应用会获取文本文件的部分内容,创建这些部分内容的嵌入,并在生成输出文本时引用这些嵌入。
以下代码段可在 MainActivity.kt 中找到:
class MainActivity : ComponentActivity() {
lateinit var chatViewModel: ChatViewModel
...
chatViewModel.memorizeChunks("sample_context.txt")
...
}
分块
为简单起见,sample_context.txt
文件包含示例应用用于创建分块的 <chunk_splitter>
标记。然后,为每个分块创建嵌入。在生产应用中,分块大小是一项关键考虑因素。如果分块过大,矢量不够具体,无法发挥作用;如果分块过小,则不包含足够的上下文。
示例应用通过 RagPipeline.kt 中的 memorizeChunks
函数处理分块。
嵌入
该应用提供两种文本嵌入途径:
- Gecko 嵌入器:使用 Gecko 模型进行本地(设备端)文本嵌入提取。
- Gemini 嵌入器:使用生成式语言 Cloud API 在云端提取文本嵌入。
示例应用会根据用户打算在本地还是通过 Google Cloud 计算嵌入,选择嵌入器。以下代码段可在 RagPipeline.kt 中找到:
private val embedder: Embedder<String> = if (COMPUTE_EMBEDDINGS_LOCALLY) {
GeckoEmbeddingModel(
GECKO_MODEL_PATH,
Optional.of(TOKENIZER_MODEL_PATH),
USE_GPU_FOR_EMBEDDINGS,
)
} else {
GeminiEmbedder(
GEMINI_EMBEDDING_MODEL,
GEMINI_API_KEY
)
}
数据库
示例应用使用 SQLite (SqliteVectorStore
) 存储文本嵌入。您还可以将 DefaultVectorStore
数据库用于非持久性矢量存储。
您可以在 RagPipeline.kt 中找到以下代码段:
private val config = ChainConfig.create(
mediaPipeLanguageModel, PromptBuilder(QA_PROMPT_TEMPLATE1),
DefaultSemanticTextMemory(
SqliteVectorStore(768), embedder
)
)
示例应用将嵌入维度设置为 768,这表示矢量数据库中每个矢量的长度。
Chain
RAG SDK 提供了链,可将多个 RAG 组件组合到单个流水线中。您可以使用链来协调检索和查询模型。该 API 基于 Chain 接口。
示例应用使用检索和推理链。您可以在 RagPipeline.kt 中找到以下代码段:
private val retrievalAndInferenceChain = RetrievalAndInferenceChain(config)
在模型生成回答时,系统会调用该链:
suspend fun generateResponse(
prompt: String,
callback: AsyncProgressListener<LanguageModelResponse>?
): String =
coroutineScope {
val retrievalRequest =
RetrievalRequest.create(
prompt,
RetrievalConfig.create(2, 0.0f, TaskType.QUESTION_ANSWERING)
)
retrievalAndInferenceChain.invoke(retrievalRequest, callback).await().text
}