本指南提供操作說明,將 Hugging Face Safetensors 格式 (.safetensors
) 的 Gemma 模型轉換為 MediaPipe Task 檔案格式 (.task
)。這項轉換作業至關重要,因為您必須使用 MediaPipe LLM Inference API 和 LiteRT 執行階段,才能在 Android 和 iOS 裝置上部署預先訓練或微調的 Gemma 模型,進行裝置端推論。
如要建立必要的工作套件 (.task
),請使用 AI Edge Torch。這項工具會將 PyTorch 模型匯出至多重簽章 LiteRT (.tflite
) 模型,這類模型與 MediaPipe LLM 推論 API 相容,適合在行動應用程式的 CPU 後端上執行。
最終的 .task
檔案是 MediaPipe 要求的獨立套件,其中包含 LiteRT 模型、權杖化工具模型和必要的中繼資料。這是因為權杖化工具 (可將文字提示轉換為模型適用的權杖嵌入) 必須與 LiteRT 模型一併封裝,才能進行端對端推論。
以下是逐步操作說明:
1. 取得 Gemma 模型
你可以透過兩種方式開始使用。
選項 A:使用現有的微調模型
如果您已準備好微調 Gemma 模型,請直接進行下一個步驟。
選項 B:下載官方指令調整模型
如需模型,可以從 Hugging Face Hub 下載經過指令微調的 Gemma。
設定必要工具:
python -m venv hf
source hf/bin/activate
pip install huggingface_hub[cli]
下載模型:
Hugging Face Hub 中的模型會以模型 ID 識別,通常採用 <organization_or_username>/<model_name>
格式。舉例來說,如要下載官方 Google Gemma 3 270M 指令微調模型,請使用:
hf download google/gemma-3-270m-it --local-dir "PATH_TO_HF_MODEL"
#"google/gemma-3-1b-it", etc
2. 將模型轉換並量化為 LiteRT
設定 Python 虛擬環境,並安裝 AI Edge Torch 套件的最新穩定版本:
python -m venv ai-edge-torch
source ai-edge-torch/bin/activate
pip install "ai-edge-torch>=0.6.0"
使用下列指令碼將 Safetensor 轉換為 LiteRT 模型。
from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.export_config import ExportConfig
from ai_edge_torch.generative.layers import kv_cache
pytorch_model = gemma3.build_model_270m("PATH_TO_HF_MODEL")
# If you are using Gemma 3 1B
#pytorch_model = gemma3.build_model_1b("PATH_TO_HF_MODEL")
export_config = ExportConfig()
export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
export_config.mask_as_input = True
converter.convert_to_tflite(
pytorch_model,
output_path="OUTPUT_DIR_PATH",
output_name_prefix="my-gemma3",
prefill_seq_len=2048,
kv_cache_max_len=4096,
quantize="dynamic_int8",
export_config=export_config,
)
請注意,這項程序耗時較長,且取決於電腦的處理速度。以 2025 年的 8 核心 CPU 為例,2.7 億個參數的模型需要 5 到 10 分鐘,10 億個參數的模型則需要 10 到 30 分鐘。
最終輸出內容 (LiteRT 模型) 會儲存至您指定的 OUTPUT_DIR_PATH
。
請根據目標裝置的記憶體和效能限制調整下列值。
kv_cache_max_len
:定義模型工作記憶體 (KV 快取) 的總分配大小。這項容量是硬性限制,必須足以儲存提示的權杖總和 (預先填入) 和所有後續產生的權杖 (解碼)。prefill_seq_len
:指定預先填入分塊的輸入提示符記數。使用預填區塊處理輸入提示時,整個序列 (例如50,000 個權杖) 不會一次計算,而是會分成可管理的區段 (例如 2,048 個權杖的區塊),依序載入快取,避免發生記憶體不足錯誤。quantize
:所選量化配置的字串。以下列出 Gemma 3 適用的量化配方。none
:無量化fp16
:所有作業的 FP16 權重、FP32 啟用和浮點運算dynamic_int8
:FP32 啟用、INT8 權重和整數運算weight_only_int8
:FP32 啟用、INT8 權重和浮點運算
3. 從 LiteRT 和權杖化工具建立 Task Bundle
設定 Python 虛擬環境並安裝 mediapipe Python 套件:
python -m venv mediapipe
source mediapipe/bin/activate
pip install mediapipe
使用 genai.bundler
程式庫將模型套裝組合:
from mediapipe.tasks.python.genai import bundler
config = bundler.BundleConfig(
tflite_model="PATH_TO_LITERT_MODEL.tflite",
tokenizer_model="PATH_TO_HF_MODEL/tokenizer.model",
start_token="<bos>",
stop_tokens=["<eos>", "<end_of_turn>"],
output_filename="PATH_TO_TASK_BUNDLE.task",
prompt_prefix="<start_of_turn>user\n",
prompt_suffix="<end_of_turn>\n<start_of_turn>model\n",
)
bundler.create_bundle(config)
bundler.create_bundle
函式會建立 .task
檔案,其中包含執行模型所需的所有資訊。
4. 在 Android 上使用 Mediapipe 進行推論
使用基本設定選項初始化工作:
// Default values for LLM models
private object LLMConstants {
const val MODEL_PATH = "PATH_TO_TASK_BUNDLE_ON_YOUR_DEVICE.task"
const val DEFAULT_MAX_TOKEN = 4096
const val DEFAULT_TOPK = 64
const val DEFAULT_TOPP = 0.95f
const val DEFAULT_TEMPERATURE = 1.0f
}
// Set the configuration options for the LLM Inference task
val taskOptions = LlmInference.LlmInferenceOptions.builder()
.setModelPath(LLMConstants.MODEL_PATH)
.setMaxTokens(LLMConstants.DEFAULT_MAX_TOKEN)
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, taskOptions)
llmInferenceSession =
LlmInferenceSession.createFromOptions(
llmInference,
LlmInferenceSession.LlmInferenceSessionOptions.builder()
.setTopK(LLMConstants.DEFAULT_TOPK)
.setTopP(LLMConstants.DEFAULT_TOPP)
.setTemperature(LLMConstants.DEFAULT_TEMPERATURE)
.build(),
)
使用 generateResponse()
方法生成文字回覆。
val result = llmInferenceSession.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")
如要串流回應,請使用 generateResponseAsync()
方法。
llmInferenceSession.generateResponseAsync(inputPrompt) { partialResult, done ->
logger.atInfo().log("partial result: $partialResult")
}
詳情請參閱 Android 適用的 LLM 推論指南。
後續步驟
使用 Gemma 模型建構及探索更多內容: