LLM Inference API を使用すると、ウェブ アプリケーションのブラウザで大規模言語モデル(LLM)を完全に実行できます。これにより、テキストの生成、自然言語形式での情報の取得、ドキュメントの要約など、幅広いタスクを実行できます。このタスクには、複数のテキストからテキストへの大規模言語モデルが組み込まれているため、最新のオンデバイス生成 AI モデルをウェブアプリに適用できます。
このタスクは、Gemma の Gemma-2 2B、Gemma 2B、Gemma 7B の各バリアントをサポートしています。Gemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。また、Phi-2、Falcon-RW-1B、StableLM-3B などの外部モデルもサポートしています。
このタスクの動作は、MediaPipe Studio のデモで確認できます。このタスクの機能、モデル、構成オプションの詳細については、概要をご覧ください。
サンプルコード
LLM 推論 API のサンプル アプリケーションには、このタスクの基本的な JavaScript 実装が用意されています。このサンプルアプリを使用して、独自のテキスト生成アプリの作成を開始できます。
LLM 推論 API のサンプルアプリは、GitHub で入手できます。
セットアップ
このセクションでは、LLM 推論 API を使用するように開発環境とコード プロジェクトを設定する主な手順について説明します。プラットフォーム バージョンの要件など、MediaPipe Tasks を使用する開発環境の設定に関する一般的な情報については、ウェブ向けの設定ガイドをご覧ください。
ブラウザの互換性
LLM 推論 API を使用するには、WebGPU に対応したウェブブラウザが必要です。互換性のあるブラウザの一覧については、GPU ブラウザの互換性をご覧ください。
JavaScript パッケージ
LLM 推論 API コードは @mediapipe/tasks-genai
パッケージで使用できます。これらのライブラリは、プラットフォームの設定ガイドに記載されているリンクから確認してダウンロードできます。
ローカル ステージングに必要なパッケージをインストールします。
npm install @mediapipe/tasks-genai
サーバーにデプロイするには、jsDelivr などのコンテンツ配信ネットワーク(CDN)サービスを使用して、HTML ページにコードを直接追加します。
<head>
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/genai_bundle.cjs"
crossorigin="anonymous"></script>
</head>
モデル
MediaPipe LLM Inference API には、このタスクに対応したトレーニング済みモデルが必要です。ウェブ アプリケーションの場合、モデルは GPU と互換性がある必要があります。
LLM Inference API で使用可能なトレーニング済みモデルの詳細については、タスクの概要のモデルのセクションをご覧ください。
モデルのダウンロード
LLM Inference API を初期化する前に、サポートされているモデルのいずれかをダウンロードし、プロジェクト ディレクトリ内にファイルを保存します。
- Gemma-2 2B: Gemma ファミリーの最新バージョンのモデル。Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーの一部です。
- Gemma 2B: Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデル ファミリーの一部です。質問応答、要約、推論など、さまざまなテキスト生成タスクに適しています。
- Phi-2: 27 億個のパラメータを持つ Transformer モデル。質問応答、チャット、コード形式に最適です。
- Falcon-RW-1B: RefinedWeb の 3500 億トークンでトレーニングされた 10 億パラメータの因果デコーダ専用モデル。
- StableLM-3B: 30 億のパラメータのデコーダ専用言語モデル。多様な英語とコードのデータセットの 100 兆トークンで事前トレーニングされています。
サポートされているモデルに加えて、Google の AI Edge Torch を使用して、PyTorch モデルをマルチシグネチャ LiteRT(tflite
)モデルにエクスポートできます。詳細については、PyTorch モデル用の Torch 生成コンバータをご覧ください。
Kaggle Models で入手できる Gemma-2 2B を使用することをおすすめします。使用可能な他のモデルの詳細については、タスクの概要のモデルのセクションをご覧ください。
モデルを MediaPipe 形式に変換する
LLM 推論 API は、2 つのカテゴリのモデルと互換性があります。一部のモデルではモデル変換が必要です。次の表を使用して、モデルに必要なステップ方法を特定します。
モデル | コンバージョンの方法 | 対応プラットフォーム | ファイル形式 | |
---|---|---|---|---|
サポートされているモデル | Gemma 2B、Gemma 7B、Gemma-2 2B、Phi-2、StableLM、Falcon | MediaPipe | Android、iOS、ウェブ | .bin |
その他の PyTorch モデル | すべての PyTorch LLM モデル | AI Edge Torch 生成ライブラリ | Android、iOS | .task |
Gemma 2B、Gemma 7B、Gemma-2 2B の変換済み .bin
ファイルは Kaggle でホストされています。これらのモデルは、LLM Inference API を使用して直接デプロイできます。他のモデルを変換する方法については、モデルの変換をご覧ください。
モデルをプロジェクト ディレクトリに追加する
モデルをプロジェクト ディレクトリに保存します。
<dev-project-root>/assets/gemma-2b-it-gpu-int4.bin
baseOptions
オブジェクトの modelAssetPath
パラメータを使用して、モデルのパスを指定します。
baseOptions: { modelAssetPath: `/assets/gemma-2b-it-gpu-int4.bin`}
タスクを作成する
LLM Inference API の createFrom...()
関数のいずれかを使用して、推論を実行するタスクを準備します。createFromModelPath()
関数は、トレーニング済みモデル ファイルの相対パスまたは絶対パスで使用できます。このコードサンプルでは、createFromOptions()
関数を使用しています。使用可能な構成オプションの詳細については、構成オプションをご覧ください。
次のコードは、このタスクをビルドして構成する方法を示しています。
const genai = await FilesetResolver.forGenAiTasks(
// path/to/wasm/root
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
llmInference = await LlmInference.createFromOptions(genai, {
baseOptions: {
modelAssetPath: '/assets/gemma-2b-it-gpu-int4.bin'
},
maxTokens: 1000,
topK: 40,
temperature: 0.8,
randomSeed: 101
});
設定オプション
このタスクには、ウェブアプリと JavaScript アプリ用の次の構成オプションがあります。
オプション名 | 説明 | 値の範囲 | デフォルト値 |
---|---|---|---|
modelPath |
プロジェクト ディレクトリ内のモデルの保存場所へのパス。 | 経路 | なし |
maxTokens |
モデルが処理するトークンの最大数(入力トークン + 出力トークン)。 | Integer | 512 |
topK |
生成の各ステップでモデルが考慮するトークンの数。予測を、最も高い確率を持つ上位 k 個のトークンに制限します。 | Integer | 40 |
temperature |
生成時に導入されるランダム性の量。温度が高いほど、生成されるテキストの創造性が高まり、温度が低いほど、生成が予測しやすくなります。 | 浮動小数点数 | 0.8 |
randomSeed |
テキスト生成時に使用される乱数シード。 | Integer | 0 |
loraRanks |
ランタイム中に LoRA モデルで使用される LoRA ランク。注: これは GPU モデルにのみ対応しています。 | 整数配列 | なし |
データの準備
LLM Inference API はテキスト(string
)データを受け入れます。このタスクは、トークン化やテンソルの前処理など、データ入力の前処理を処理します。
すべての前処理は generateResponse()
関数内で処理されます。入力テキストの追加の前処理は必要ありません。
const inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday.";
タスクを実行する
LLM Inference API は、generateResponse()
関数を使用して推論をトリガーします。テキスト分類の場合、これは入力テキストの可能性のあるカテゴリを返すことを意味します。
次のコードは、タスクモデルを使用して処理を実行する方法を示しています。
const response = await llmInference.generateResponse(inputPrompt);
document.getElementById('output').textContent = response;
レスポンスをストリーミングするには、次のコマンドを使用します。
llmInference.generateResponse(
inputPrompt,
(partialResult, done) => {
document.getElementById('output').textContent += partialResult;
});
結果を処理して表示する
LLM Inference API は、生成されたレスポンス テキストを含む文字列を返します。
Here's a draft you can use:
Subject: Lunch on Saturday Reminder
Hi Brett,
Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.
Looking forward to it!
Best,
[Your Name]
LoRA モデルのカスタマイズ
Mediapipe LLM 推論 API は、大規模言語モデルの Low-Rank Adaptation(LoRA)をサポートするように構成できます。ファインチューニングされた LoRA モデルを使用すると、デベロッパーは費用対効果の高いトレーニング プロセスで LLM の動作をカスタマイズできます。
LLM Inference API の LoRA サポートは、GPU バックエンドのすべての Gemma バリアントと Phi-2 モデルで機能します。LoRA 重みは、アテンション レイヤにのみ適用されます。この最初の実装は、今後の開発のための試験運用版 API として機能します。今後のアップデートでは、より多くのモデルとさまざまなタイプのレイヤをサポートする予定です。
LoRA モデルを準備する
HuggingFace の手順に沿って、サポートされているモデルタイプ(Gemma または Phi-2)を使用して、独自のデータセットでファインチューニングされた LoRA モデルをトレーニングします。Gemma-2 2B、Gemma 2B、Phi-2 モデルは、HuggingFace で safetensors 形式で使用できます。LLM Inference API は、アテンション レイヤでの LoRA のみをサポートしているため、LoraConfig
の作成時にアテンション レイヤのみを指定します。
# For Gemma
from peft import LoraConfig
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
# For Phi-2
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)
テスト用に、HuggingFace で利用可能な LLM 推論 API に適合する、一般公開されているファインチューニング済みの LoRA モデルがあります。たとえば、Gemma-2B の場合は monsterapi/gemma-2b-lora-maths-orca-200k、Phi-2 の場合は lole25/phi-2-sft-ultrachat-lora です。
準備したデータセットでトレーニングしてモデルを保存すると、ファインチューニングされた LoRA モデルの重みを含む adapter_model.safetensors
ファイルが作成されます。safetensors ファイルは、モデル変換で使用される LoRA チェックポイントです。
次のステップとして、MediaPipe Python パッケージを使用して、モデルの重みを TensorFlow Lite Flatbuffer に変換する必要があります。ConversionConfig
には、ベースモデル オプションと追加の LoRA オプションを指定する必要があります。この API は GPU を使用した LoRA 推論のみをサポートしているため、バックエンドは 'gpu'
に設定する必要があります。
import mediapipe as mp
from mediapipe.tasks.python.genai import converter
config = converter.ConversionConfig(
# Other params related to base model
...
# Must use gpu backend for LoRA conversion
backend='gpu',
# LoRA related params
lora_ckpt=LORA_CKPT,
lora_rank=LORA_RANK,
lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)
converter.convert_checkpoint(config)
コンバータは、ベースモデル用と LoRA モデル用の 2 つの TFLite FlatBuffer ファイルを出力します。
LoRA モデルの推論
ウェブ、Android、iOS の LLM 推論 API が更新され、LoRA モデル推論がサポートされるようになりました。
ウェブは実行時に動的 LoRA をサポートします。つまり、ユーザーは初期化時に使用する LoRA ランクを宣言し、実行時に異なる LoRA モデルをスワップできます。const genai = await FilesetResolver.forGenAiTasks(
// path/to/wasm/root
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
const llmInference = await LlmInference.createFromOptions(genai, {
// options for the base model
...
// LoRA ranks to be used by the LoRA models during runtime
loraRanks: [4, 8, 16]
});
実行時に、ベースモデルが初期化された後、使用する LoRA モデルを読み込みます。また、LLM レスポンスを生成するときに LoRA モデル参照を渡して、LoRA モデルをトリガーします。
// Load several LoRA models. The returned LoRA model reference is used to specify
// which LoRA model to be used for inference.
loraModelRank4 = await llmInference.loadLoraModel(loraModelRank4Url);
loraModelRank8 = await llmInference.loadLoraModel(loraModelRank8Url);
// Specify LoRA model to be used during inference
llmInference.generateResponse(
inputPrompt,
loraModelRank4,
(partialResult, done) => {
document.getElementById('output').textContent += partialResult;
});