iOS 向け LLM 推論ガイド

LLM Inference API を使用すると、iOS アプリケーションで大規模言語モデル(LLM)を完全にオンデバイスで実行できます。これにより、テキストの生成、自然言語形式での情報の取得、ドキュメントの要約など、幅広いタスクを実行できます。このタスクには、複数のテキストからテキストへの大規模言語モデルが組み込まれているため、最新のオンデバイス生成 AI モデルを iOS アプリに適用できます。

このタスクは、Gemma の Gemma-2 2B、Gemma 2B、Gemma 7B の各バリアントをサポートしています。Gemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。また、Phi-2Falcon-RW-1BStableLM-3B などの外部モデルもサポートしています。

サポートされているモデルに加えて、Google の AI Edge Torch を使用して、PyTorch モデルをマルチシグネチャ LiteRT(tflite)モデルにエクスポートできます。このモデルは、トークン化パラメータとバンドルされ、LLM 推論 API と互換性のあるタスク バンドルを作成します。

このタスクの動作は、MediaPipe Studio のデモで確認できます。このタスクの機能、モデル、構成オプションの詳細については、概要をご覧ください。

サンプルコード

MediaPipe Tasks のサンプルコードは、iOS 向け LLM 推論 API アプリの基本的な実装です。このアプリは、独自の iOS アプリの開始点として使用できます。また、既存のアプリを変更する際に参照することもできます。LLM 推論 API のサンプルコードは GitHub でホストされています。

コードをダウンロードする

次の手順では、git コマンドライン ツールを使用してサンプルコードのローカルコピーを作成する方法について説明します。

サンプルコードをダウンロードするには:

  1. 次のコマンドを使用して、Git リポジトリのクローンを作成します。

    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. 必要に応じて、LLM Inference API サンプルアプリのファイルのみが取得されるように、スパース チェックアウトを使用するように git インスタンスを構成します。

    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/ios/
    

ローカル バージョンのサンプルコードを作成したら、MediaPipe タスク ライブラリをインストールし、Xcode を使用してプロジェクトを開いてアプリを実行できます。手順については、iOS 用セットアップガイドをご覧ください。

セットアップ

このセクションでは、LLM Inference API を使用するように開発環境とコード プロジェクトを設定する主な手順について説明します。プラットフォーム バージョンの要件など、MediaPipe タスクを使用する開発環境の設定に関する一般的な情報については、iOS 用セットアップ ガイドをご覧ください。

依存関係

LLM Inference API は MediaPipeTasksGenai ライブラリを使用します。このライブラリは CocoaPods を使用してインストールする必要があります。このライブラリは Swift アプリと Objective-C アプリの両方に対応しており、言語固有の追加設定は必要ありません。

macOS に CocoaPods をインストールする手順については、CocoaPods インストール ガイドをご覧ください。アプリに必要な Pod を使用して Podfile を作成する方法については、CocoaPods の使用をご覧ください。

次のコードを使用して、PodfileMediaPipeTasksGenai Pod を追加します。

target 'MyLlmInferenceApp' do
  use_frameworks!
  pod 'MediaPipeTasksGenAI'
  pod 'MediaPipeTasksGenAIC'
end

アプリに単体テスト ターゲットが含まれている場合は、Podfile の設定について詳しくは、iOS 用セットアップ ガイドをご覧ください。

モデル

MediaPipe LLM Inference API タスクには、このタスクと互換性のあるトレーニング済みモデルが必要です。LLM Inference API で使用可能なトレーニング済みモデルの詳細については、タスクの概要のモデルのセクションをご覧ください。

モデルのダウンロード

モデルをダウンロードし、Xcode を使用してプロジェクト ディレクトリに追加します。Xcode プロジェクトにファイルを追加する方法については、Xcode プロジェクト内のファイルとフォルダの管理をご覧ください。

LLM Inference API を初期化する前に、サポートされているモデルのいずれかをダウンロードし、プロジェクト ディレクトリ内にファイルを保存します。

  • Gemma-2 2B: Gemma ファミリーの最新バージョンのモデル。Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーの一部です。
  • Gemma 2B: Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデル ファミリーの一部です。質問応答、要約、推論など、さまざまなテキスト生成タスクに適しています。
  • Phi-2: 27 億個のパラメータを持つ Transformer モデル。質問応答、チャット、コード形式に最適です。
  • Falcon-RW-1B: RefinedWeb の 3500 億トークンでトレーニングされた 10 億パラメータの因果デコーダ専用モデル。
  • StableLM-3B: 30 億のパラメータのデコーダ専用言語モデル。多様な英語とコードのデータセットの 100 兆トークンで事前トレーニングされています。

サポートされているモデルに加えて、Google の AI Edge Torch を使用して、PyTorch モデルをマルチシグネチャ LiteRT(tflite)モデルにエクスポートできます。詳細については、PyTorch モデル用の Torch 生成コンバータをご覧ください。

Kaggle Models で入手できる Gemma-2 2B を使用することをおすすめします。使用可能な他のモデルの詳細については、タスクの概要のモデルのセクションをご覧ください。

モデルを MediaPipe 形式に変換する

LLM 推論 API は、2 つのカテゴリのモデルと互換性があります。一部のモデルではモデル変換が必要です。次の表を使用して、モデルに必要なステップ方法を特定します。

モデル コンバージョンの方法 対応プラットフォーム ファイル形式
サポートされているモデル Gemma 2B、Gemma 7B、Gemma-2 2B、Phi-2、StableLM、Falcon MediaPipe Android、iOS、ウェブ .bin
その他の PyTorch モデル すべての PyTorch LLM モデル AI Edge Torch 生成ライブラリ Android、iOS .task

Gemma 2B、Gemma 7B、Gemma-2 2B の変換済み .bin ファイルは Kaggle でホストされています。これらのモデルは、LLM Inference API を使用して直接デプロイできます。他のモデルを変換する方法については、モデルの変換をご覧ください。

タスクを作成する

LLM Inference API タスクを作成するには、いずれかのイニシャライザを呼び出します。LlmInference(options:) イニシャライザは、構成オプションの値を設定します。

カスタマイズされた構成オプションで初期化された LLM 推論 API が不要な場合は、LlmInference(modelPath:) イニシャライザを使用して、デフォルト オプションで LLM 推論 API を作成できます。構成オプションの詳細については、構成の概要をご覧ください。

次のコードは、このタスクをビルドして構成する方法を示しています。

import MediaPipeTasksGenai

let modelPath = Bundle.main.path(forResource: "model",
                                      ofType: "bin")

let options = LlmInferenceOptions()
options.baseOptions.modelPath = modelPath
options.maxTokens = 1000
options.topk = 40
options.temperature = 0.8
options.randomSeed = 101

let llmInference = try LlmInference(options: options)

設定オプション

このタスクには、iOS アプリ用の次の構成オプションがあります。

オプション名 説明 値の範囲 デフォルト値
modelPath プロジェクト ディレクトリ内のモデルの保存場所へのパス。 経路 なし
maxTokens モデルが処理するトークンの最大数(入力トークン + 出力トークン)。 Integer 512
topk 生成の各ステップでモデルが考慮するトークンの数。予測を、最も高い確率を持つ上位 k 個のトークンに制限します。 Integer 40
temperature 生成時に導入されるランダム性の量。温度が高いほど、生成されるテキストの創造性が高まり、温度が低いほど、生成が予測しやすくなります。 浮動小数点数 0.8
randomSeed テキスト生成時に使用される乱数シード。 Integer 0
loraPath デバイス上のローカル LoRA モデルの絶対パス。注: これは GPU モデルにのみ対応しています。 経路 なし

データの準備

LLM Inference API はテキストデータで動作します。このタスクは、トークン化やテンソルの前処理など、データ入力の前処理を処理します。

すべての前処理は generateResponse(inputText:) 関数内で処理されます。事前に入力テキストを追加で前処理する必要はありません。

let inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

タスクを実行する

LLM Inference API を実行するには、generateResponse(inputText:) メソッドを使用します。LLM Inference API は、入力テキストの可能性のあるカテゴリを返します。

let result = try LlmInference.generateResponse(inputText: inputPrompt)

レスポンスをストリーミングするには、generateResponseAsync(inputText:) メソッドを使用します。

let resultStream =  LlmInference.generateResponseAsync(inputText: inputPrompt)

do {
  for try await partialResult in resultStream {
    print("\(partialResult)")
  }
  print("Done")
}
catch {
  print("Response error: '\(error)")
}

結果を処理して表示する

LLM Inference API は、生成されたレスポンス テキストを返します。

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

LoRA モデルのカスタマイズ

Mediapipe LLM 推論 API は、大規模言語モデルの Low-Rank Adaptation(LoRA)をサポートするように構成できます。ファインチューニングされた LoRA モデルを使用すると、デベロッパーは費用対効果の高いトレーニング プロセスで LLM の動作をカスタマイズできます。

LLM Inference API の LoRA サポートは、GPU バックエンドのすべての Gemma バリアントと Phi-2 モデルで機能します。LoRA 重みは、アテンション レイヤにのみ適用されます。この最初の実装は、今後の開発のための試験運用版 API として機能します。今後のアップデートでは、より多くのモデルとさまざまなタイプのレイヤをサポートする予定です。

LoRA モデルを準備する

HuggingFace の手順に沿って、サポートされているモデルタイプ(Gemma または Phi-2)を使用して、独自のデータセットでファインチューニングされた LoRA モデルをトレーニングします。Gemma-2 2BGemma 2BPhi-2 モデルは、HuggingFace で safetensors 形式で使用できます。LLM Inference API は、アテンション レイヤでの LoRA のみをサポートしているため、LoraConfig の作成時にアテンション レイヤのみを指定します。

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

テスト用に、HuggingFace で利用可能な LLM 推論 API に適合する、一般公開されているファインチューニング済みの LoRA モデルがあります。たとえば、Gemma-2B の場合は monsterapi/gemma-2b-lora-maths-orca-200k、Phi-2 の場合は lole25/phi-2-sft-ultrachat-lora です。

準備したデータセットでトレーニングしてモデルを保存すると、ファインチューニングされた LoRA モデルの重みを含む adapter_model.safetensors ファイルが作成されます。safetensors ファイルは、モデル変換で使用される LoRA チェックポイントです。

次のステップとして、MediaPipe Python パッケージを使用して、モデルの重みを TensorFlow Lite Flatbuffer に変換する必要があります。ConversionConfig には、ベースモデル オプションと追加の LoRA オプションを指定する必要があります。この API は GPU を使用した LoRA 推論のみをサポートしているため、バックエンドは 'gpu' に設定する必要があります。

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

コンバータは、ベースモデル用と LoRA モデル用の 2 つの TFLite FlatBuffer ファイルを出力します。

LoRA モデルの推論

ウェブ、Android、iOS の LLM 推論 API が更新され、LoRA モデル推論がサポートされるようになりました。

iOS は、初期化時に静的 LoRA をサポートしています。LoRA モデルを読み込むには、LoRA モデルパスとベース LLM を指定します。

import MediaPipeTasksGenai

let modelPath = Bundle.main.path(forResource: "model",
                                      ofType: "bin")
let loraPath= Bundle.main.path(forResource: "lora_model",
                                      ofType: "bin")
let options = LlmInferenceOptions()
options.modelPath = modelPath
options.maxTokens = 1000
options.topk = 40
options.temperature = 0.8
options.randomSeed = 101
options.loraPath = loraPath

let llmInference = try LlmInference(options: options)

LoRA で LLM 推論を実行するには、ベースモデルと同じ generateResponse() メソッドまたは generateResponseAsync() メソッドを使用します。