Hugging Face Safetensors を MediaPipe Task に変換する

このガイドでは、Hugging Face Safetensors 形式(.safetensors)の Gemma モデルを MediaPipe タスク ファイル形式(.task)に変換する手順について説明します。この変換は、MediaPipe LLM 推論 API と LiteRT ランタイムを使用して、Android と iOS でオンデバイス推論用に事前トレーニング済みまたはファインチューニング済みの Gemma モデルをデプロイするために不可欠です。

Hugging Face モデルを MediaPipe タスクファイルにパッケージ化するフローチャート

必要なタスク バンドル(.task)を作成するには、AI Edge Torch を使用します。このツールは、PyTorch モデルをマルチシグネチャ LiteRT(.tflite)モデルにエクスポートします。このモデルは MediaPipe LLM 推論 API と互換性があり、モバイル アプリケーションの CPU バックエンドでの実行に適しています。

最終的な .task ファイルは、MediaPipe で必要な自己完結型のパッケージで、LiteRT モデル、トークナイザー モデル、重要なメタデータがバンドルされています。このバンドルが必要なのは、エンドツーエンドの推論を可能にするために、トークナイザー(テキスト プロンプトをモデルのトークン エンベディングに変換する)を LiteRT モデルとともにパッケージ化する必要があるためです。

プロセスの手順は次のとおりです。

プロセスのステップバイステップの内訳

1. Gemma モデルを入手する

開始するには、次の 2 つの方法があります。

オプション A。既存のファインチューニング済みモデルを使用する

ファインチューニングされた Gemma モデルが準備できている場合は、次のステップに進みます。

オプション B。公式の指示用チューニング済みモデルをダウンロードする

モデルが必要な場合は、Hugging Face Hub から指示に従って調整された Gemma をダウンロードできます。

必要なツールをセットアップします。

python -m venv hf
source hf/bin/activate
pip install huggingface_hub[cli]

モデルをダウンロードする:

Hugging Face Hub のモデルは、通常 <organization_or_username>/<model_name> 形式のモデル ID で識別されます。たとえば、Google の公式 Gemma 3 270M 指示チューニング済みモデルをダウンロードするには、次のコマンドを使用します。

hf download google/gemma-3-270m-it --local-dir "PATH_TO_HF_MODEL"
#"google/gemma-3-1b-it", etc

2. モデルを LiteRT に変換して量子化する

Python 仮想環境を設定し、AI Edge Torch パッケージの最新の安定版リリースをインストールします。

python -m venv ai-edge-torch
source ai-edge-torch/bin/activate
pip install "ai-edge-torch>=0.6.0"

次のスクリプトを使用して、Safetensor を LiteRT モデルに変換します。

from ai_edge_torch.generative.examples.gemma3 import gemma3
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.export_config import ExportConfig
from ai_edge_torch.generative.layers import kv_cache

pytorch_model = gemma3.build_model_270m("PATH_TO_HF_MODEL")

# If you are using Gemma 3 1B
#pytorch_model = gemma3.build_model_1b("PATH_TO_HF_MODEL")

export_config = ExportConfig()
export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
export_config.mask_as_input = True

converter.convert_to_tflite(
    pytorch_model,
    output_path="OUTPUT_DIR_PATH",
    output_name_prefix="my-gemma3",
    prefill_seq_len=2048,
    kv_cache_max_len=4096,
    quantize="dynamic_int8",
    export_config=export_config,
)

このプロセスには時間がかかり、パソコンの処理速度によって異なります。参考までに、2025 年の 8 コア CPU では、2 億 7,000 万モデルに 5 ~ 10 分以上かかり、10 億モデルには 10 ~ 30 分ほどかかることがあります。

最終出力である LiteRT モデルは、指定した OUTPUT_DIR_PATH に保存されます。

ターゲット デバイスのメモリとパフォーマンスの制約に基づいて、次の値を調整します。

  • kv_cache_max_len: モデルのワーキング メモリ(KV キャッシュ)の割り当て済み合計サイズを定義します。この容量は上限であり、プロンプトのトークン(プリフィル)と後続のすべての生成トークン(デコード)の合計を保存するのに十分な容量が必要です。
  • prefill_seq_len: プリフィル チャンクの入力プロンプトのトークン数を指定します。プリフィル チャンクを使用して入力プロンプトを処理する場合、シーケンス全体(50,000 個のトークン)は一度に計算されません。代わりに、管理可能なセグメント(2,048 個のトークンのチャンクなど)に分割され、メモリ不足エラーを防ぐためにキャッシュに順番に読み込まれます。
  • quantize: 選択した量子化スキームの文字列。Gemma 3 で使用可能な量子化レシピのリストは次のとおりです。
    • none : 量子化なし
    • fp16 : すべてのオペレーションで FP16 の重み、FP32 のアクティベーション、浮動小数点計算
    • dynamic_int8 : FP32 アクティベーション、INT8 重み、整数計算
    • weight_only_int8 : FP32 アクティベーション、INT8 重み、浮動小数点計算

3. LiteRT とトークナイザーから Task Bundle を作成する

Python 仮想環境を設定し、mediapipe Python パッケージをインストールします。

python -m venv mediapipe
source mediapipe/bin/activate
pip install mediapipe

genai.bundler ライブラリを使用してモデルをバンドルします。

from mediapipe.tasks.python.genai import bundler
config = bundler.BundleConfig(
    tflite_model="PATH_TO_LITERT_MODEL.tflite",
    tokenizer_model="PATH_TO_HF_MODEL/tokenizer.model",
    start_token="<bos>",
    stop_tokens=["<eos>", "<end_of_turn>"],
    output_filename="PATH_TO_TASK_BUNDLE.task",
    prompt_prefix="<start_of_turn>user\n",
    prompt_suffix="<end_of_turn>\n<start_of_turn>model\n",
)
bundler.create_bundle(config)

bundler.create_bundle 関数は、モデルの実行に必要なすべての情報を含む .task ファイルを作成します。

4. Android での Mediapipe を使用した推論

基本的な構成オプションを使用してタスクを初期化します。

// Default values for LLM models
private object LLMConstants {
    const val MODEL_PATH = "PATH_TO_TASK_BUNDLE_ON_YOUR_DEVICE.task"
    const val DEFAULT_MAX_TOKEN = 4096
    const val DEFAULT_TOPK = 64
    const val DEFAULT_TOPP = 0.95f
    const val DEFAULT_TEMPERATURE = 1.0f
}

// Set the configuration options for the LLM Inference task
val taskOptions = LlmInference.LlmInferenceOptions.builder()
    .setModelPath(LLMConstants.MODEL_PATH)
    .setMaxTokens(LLMConstants.DEFAULT_MAX_TOKEN)
    .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, taskOptions)
llmInferenceSession =
    LlmInferenceSession.createFromOptions(
        llmInference,
        LlmInferenceSession.LlmInferenceSessionOptions.builder()
            .setTopK(LLMConstants.DEFAULT_TOPK)
            .setTopP(LLMConstants.DEFAULT_TOPP)
            .setTemperature(LLMConstants.DEFAULT_TEMPERATURE)
            .build(),
    )

generateResponse() メソッドを使用してテキスト レスポンスを生成します。

val result = llmInferenceSession.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

レスポンスをストリーミングするには、generateResponseAsync() メソッドを使用します。

llmInferenceSession.generateResponseAsync(inputPrompt) { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
}

詳しくは、Android 向け LLM 推論ガイドをご覧ください。

次のステップ

Gemma モデルでさらに構築して探索する: