PyTorch を使用して Gemma を実行する

このガイドでは、PyTorch フレームワークを使用して Gemma を実行する方法について説明します。Gemma リリース 3 以降のモデルに画像データを使用してプロンプトを表示する方法も説明します。Gemma PyTorch の実装の詳細については、プロジェクト リポジトリの README をご覧ください。

セットアップ

以降のセクションでは、Kaggle からダウンロードするための Gemma モデルにアクセスする方法、認証変数の設定方法、依存関係のインストール方法、パッケージのインポート方法など、開発環境を設定する方法について説明します。

システム要件

この Gemma Pytorch ライブラリでは、Gemma モデルを実行するために GPU または TPU プロセッサが必要です。Gemma 1B、2B、4B サイズのモデルを実行するには、標準の Colab CPU Python ランタイムと T4 GPU Python ランタイムで十分です。他の GPU または TPU の高度なユースケースについては、Gemma PyTorch リポジトリの README をご覧ください。

Kaggle で Gemma にアクセスする

このチュートリアルを完了するには、まず Gemma の設定の手順に沿って、次の操作を行う必要があります。

  • kaggle.com で Gemma にアクセスします。
  • Gemma モデルを実行するのに十分なリソースがある Colab ランタイムを選択します。
  • Kaggle のユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら、次のセクションに進みます。ここでは、Colab 環境の環境変数を設定します。

環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを許可します。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

依存関係のインストール

pip install -q -U torch immutabledict sentencepiece

モデルの重みをダウンロードする

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '4b':
  CONFIG = '4b-v1'
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

モデルのトークン化ツールとチェックポイントのパスを設定します。

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

実行環境を構成する

以降のセクションでは、Gemma を実行するための PyTorch 環境を準備する方法について説明します。

PyTorch 実行環境を準備する

Gemma Pytorch リポジトリのクローンを作成して、PyTorch モデル実行環境を準備します。

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

モデル構成を設定する

モデルを実行する前に、Gemma バリアント、トークン化、量子化レベルなどの構成パラメータを設定する必要があります。

# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

デバイスのコンテキストを設定する

次のコードは、モデルの実行用のデバイス コンテキストを構成します。

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

モデルをインスタンス化して読み込む

重みとともにモデルを読み込み、リクエストの実行を準備します。

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

推論を実行する

以下に、チャットモードでの生成と複数のリクエストでの生成の例を示します。

指示に従って調整された Gemma モデルは、トレーニングと推論の両方で指示のチューニング例に追加情報をアノテーションする特定のフォーマッタでトレーニングされています。アノテーションは、(1)会話内の役割を示し、(2)会話内のターンを区切ります。

関連するアノテーション トークンは次のとおりです。

  • user: ユーザーのターン
  • model: モデルの回転
  • <start_of_turn>: ダイアログ ターンの開始
  • <start_of_image>: 画像データ入力のタグ
  • <end_of_turn><eos>: ダイアログ ターンの終了

詳細については、命令でチューニングされた Gemma モデルのプロンプト形式について [こちら](https://ai.google.dev/gemma/core/prompt-structure) をご覧ください。

テキストからテキストを生成する

次のコード スニペットは、マルチターンの会話でユーザーとモデルのチャット テンプレートを使用して、指示に基づいてチューニングされた Gemma モデルのプロンプトをフォーマットする方法を示しています。

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

画像からテキストを生成する

Gemma リリース 3 以降では、プロンプトで画像を使用できます。次の例は、プロンプトに画像データを含める方法を示しています。

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)

print(model.generate(
    [['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
    device=device,
    output_len=OUTPUT_LEN,
))

その他の情報

Pytorch で Gemma を使用する方法を学習したので、ai.google.dev/gemma で Gemma が実行できるその他の多くの機能を確認できます。以下の関連リソースもご覧ください。