Android 向け LLM 推論ガイド

LLM 推論 API を使用すると、Android アプリケーションで大規模言語モデル(LLM)を完全にオンデバイスで実行できます。これにより、テキストの生成、自然言語形式での情報の取得、ドキュメントの要約など、幅広いタスクを実行できます。このタスクには、複数のテキストからテキストへの大規模言語モデルが組み込まれているため、最新のオンデバイス生成 AI モデルを Android アプリに適用できます。

このタスクは、Gemma の Gemma-2 2B、Gemma 2B、Gemma 7B の各バリアントをサポートしています。Gemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。また、Phi-2Falcon-RW-1BStableLM-3B などの外部モデルもサポートしています。

サポートされているモデルに加えて、Google の AI Edge Torch を使用して、PyTorch モデルをマルチシグネチャ LiteRT(tflite)モデルにエクスポートできます。このモデルは、トークン化パラメータとバンドルされ、LLM 推論 API と互換性のあるタスク バンドルを作成します。

このタスクの動作は、MediaPipe Studio のデモで確認できます。このタスクの機能、モデル、構成オプションの詳細については、概要をご覧ください。

サンプルコード

このガイドでは、Android 向けの基本的なテキスト生成アプリの例について説明します。このアプリは、独自の Android アプリの開始点として使用できます。また、既存のアプリを変更する際にも参照できます。サンプルコードは GitHub でホストされています。

コードをダウンロードする

次の手順では、git コマンドライン ツールを使用してサンプルコードのローカルコピーを作成する方法について説明します。

サンプルコードをダウンロードするには:

  1. 次のコマンドを使用して、Git リポジトリのクローンを作成します。
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. 必要に応じて、LLM Inference API サンプルアプリのファイルのみが存在するように、スパース チェックアウトを使用するように git インスタンスを構成します。
    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/android
    

サンプルコードのローカル バージョンを作成したら、プロジェクトを Android Studio にインポートしてアプリを実行できます。手順については、Android のセットアップ ガイドをご覧ください。

セットアップ

このセクションでは、LLM 推論 API を使用するように開発環境とコード プロジェクトを設定する主な手順について説明します。プラットフォーム バージョンの要件など、MediaPipe タスクを使用する開発環境の設定に関する一般的な情報については、Android の設定ガイドをご覧ください。

依存関係

LLM Inference API は com.google.mediapipe:tasks-genai ライブラリを使用します。この依存関係を Android アプリの build.gradle ファイルに追加します。

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}

Android 12(API 31)以降を搭載したデバイスの場合は、ネイティブ OpenCL ライブラリの依存関係を追加します。詳細については、uses-native-library タグに関するドキュメントをご覧ください。

AndroidManifest.xml ファイルに次の uses-native-library タグを追加します。

<uses-native-library android:name="libOpenCL.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-car.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false"/>

モデル

MediaPipe LLM Inference API には、このタスクに対応したトレーニング済みのテキストツーテキスト言語モデルが必要です。モデルをダウンロードしたら、必要な依存関係をインストールし、モデルを Android デバイスに push します。Gemma 以外のモデルを使用している場合は、モデルを MediaPipe と互換性のある形式に変換する必要があります。

LLM Inference API で使用可能なトレーニング済みモデルの詳細については、タスクの概要のモデルのセクションをご覧ください。

モデルのダウンロード

LLM Inference API を初期化する前に、サポートされているモデルのいずれかをダウンロードし、プロジェクト ディレクトリ内にファイルを保存します。

  • Gemma-2 2B: Gemma ファミリーの最新バージョンのモデル。Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーの一部です。
  • Gemma 2B: Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデル ファミリーの一部です。質問応答、要約、推論など、さまざまなテキスト生成タスクに適しています。
  • Phi-2: 27 億個のパラメータを持つ Transformer モデル。質問応答、チャット、コード形式に最適です。
  • Falcon-RW-1B: RefinedWeb の 3500 億トークンでトレーニングされた 10 億パラメータの因果デコーダ専用モデル。
  • StableLM-3B: 30 億のパラメータのデコーダ専用言語モデル。多様な英語とコードのデータセットの 100 兆トークンで事前トレーニングされています。

サポートされているモデルに加えて、Google の AI Edge Torch を使用して、PyTorch モデルをマルチシグネチャ LiteRT(tflite)モデルにエクスポートできます。詳細については、PyTorch モデル用の Torch 生成コンバータをご覧ください。

Kaggle Models で入手できる Gemma-2 2B を使用することをおすすめします。使用可能な他のモデルの詳細については、タスクの概要のモデルのセクションをご覧ください。

モデルを MediaPipe 形式に変換する

LLM 推論 API は、2 つのカテゴリのモデルと互換性があります。一部のモデルではモデル変換が必要です。次の表を使用して、モデルに必要なステップ方法を特定します。

モデル コンバージョンの方法 対応プラットフォーム ファイル形式
サポートされているモデル Gemma 2B、Gemma 7B、Gemma-2 2B、Phi-2、StableLM、Falcon MediaPipe Android、iOS、ウェブ .bin
その他の PyTorch モデル すべての PyTorch LLM モデル AI Edge Torch 生成ライブラリ Android、iOS .task

Gemma 2B、Gemma 7B、Gemma-2 2B の変換済み .bin ファイルは Kaggle でホストされています。これらのモデルは、LLM Inference API を使用して直接デプロイできます。他のモデルを変換する方法については、モデルの変換をご覧ください。

モデルをデバイスにプッシュする

output_path フォルダの内容を Android デバイスにプッシュします。

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin

タスクを作成する

MediaPipe LLM Inference API は、createFromOptions() 関数を使用してタスクを設定します。createFromOptions() 関数は、構成オプションの値を受け入れます。構成オプションの詳細については、構成オプションをご覧ください。

次のコードは、基本的な構成オプションを使用してタスクを初期化します。

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPATH('/data/local/.../')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

設定オプション

次の構成オプションを使用して Android アプリを設定します。

オプション名 説明 値の範囲 デフォルト値
modelPath プロジェクト ディレクトリ内のモデルの保存場所へのパス。 経路 なし
maxTokens モデルが処理するトークンの最大数(入力トークン + 出力トークン)。 Integer 512
topK 生成の各ステップでモデルが考慮するトークンの数。予測を、最も高い確率を持つ上位 k 個のトークンに制限します。 Integer 40
temperature 生成時に導入されるランダム性の量。温度が高いほど、生成されるテキストの創造性が高まり、温度が低いほど、生成が予測しやすくなります。 浮動小数点数 0.8
randomSeed テキスト生成時に使用される乱数シード。 Integer 0
loraPath デバイス上のローカル LoRA モデルの絶対パス。注: これは GPU モデルにのみ対応しています。 経路 なし
resultListener 結果を非同期で受け取るように結果リスナーを設定します。 非同期生成方法を使用する場合にのみ適用されます。 なし なし
errorListener オプションのエラー リスナーを設定します。 なし なし

データの準備

LLM Inference API は、次の入力を受け入れます。

  • prompt(文字列): 質問またはプロンプト。
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

タスクを実行する

generateResponse() メソッドを使用して、前のセクション(inputPrompt)で指定した入力テキストに対するテキスト レスポンスを生成します。これにより、生成されたレスポンスが 1 つ生成されます。

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

レスポンスをストリーミングするには、generateResponseAsync() メソッドを使用します。

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

結果を処理して表示する

LLM Inference API は、生成されたレスポンス テキストを含む LlmInferenceResult を返します。

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

LoRA モデルのカスタマイズ

Mediapipe LLM 推論 API は、大規模言語モデルの Low-Rank Adaptation(LoRA)をサポートするように構成できます。ファインチューニングされた LoRA モデルを使用すると、デベロッパーは費用対効果の高いトレーニング プロセスで LLM の動作をカスタマイズできます。

LLM Inference API の LoRA サポートは、GPU バックエンドのすべての Gemma バリアントと Phi-2 モデルで機能します。LoRA 重みは、アテンション レイヤにのみ適用されます。この最初の実装は、今後の開発のための試験運用版 API として機能します。今後のアップデートでは、より多くのモデルとさまざまなタイプのレイヤをサポートする予定です。

LoRA モデルを準備する

HuggingFace の手順に沿って、サポートされているモデルタイプ(Gemma または Phi-2)を使用して、独自のデータセットでファインチューニングされた LoRA モデルをトレーニングします。Gemma-2 2BGemma 2BPhi-2 モデルは、HuggingFace で safetensors 形式で使用できます。LLM Inference API は、アテンション レイヤでの LoRA のみをサポートしているため、LoraConfig の作成時にアテンション レイヤのみを指定します。

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

テスト用に、HuggingFace で利用可能な LLM 推論 API に適合する、一般公開されているファインチューニング済みの LoRA モデルがあります。たとえば、Gemma-2B の場合は monsterapi/gemma-2b-lora-maths-orca-200k、Phi-2 の場合は lole25/phi-2-sft-ultrachat-lora です。

準備したデータセットでトレーニングしてモデルを保存すると、ファインチューニングされた LoRA モデルの重みを含む adapter_model.safetensors ファイルが作成されます。safetensors ファイルは、モデル変換で使用される LoRA チェックポイントです。

次のステップとして、MediaPipe Python パッケージを使用して、モデルの重みを TensorFlow Lite Flatbuffer に変換する必要があります。ConversionConfig には、ベースモデル オプションと追加の LoRA オプションを指定する必要があります。この API は GPU を使用した LoRA 推論のみをサポートしているため、バックエンドは 'gpu' に設定する必要があります。

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

コンバータは、ベースモデル用と LoRA モデル用の 2 つの TFLite FlatBuffer ファイルを出力します。

LoRA モデルの推論

ウェブ、Android、iOS の LLM 推論 API が更新され、LoRA モデル推論がサポートされるようになりました。

Android は、初期化時に静的 LoRA をサポートしています。LoRA モデルを読み込むには、LoRA モデルパスとベース LLM を指定します。

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

LoRA で LLM 推論を実行するには、ベースモデルと同じ generateResponse() メソッドまたは generateResponseAsync() メソッドを使用します。