使用 PyTorch 執行 Gemma

在 ai.google.dev 上查看 在 Google Colab 中執行 在 GitHub 上查看來源

本指南說明如何使用 PyTorch 架構執行 Gemma,包括如何使用圖像資料提示 Gemma 3 以上版本的模型。如要進一步瞭解 Gemma PyTorch 實作方式,請參閱專案存放區的 README

設定

以下各節將說明如何設定開發環境,包括如何取得 Gemma 模型的存取權,以便從 Kaggle 下載模型、設定驗證變數、安裝依附元件,以及匯入套件。

系統需求

這個 Gemma Pytorch 程式庫需要 GPU 或 TPU 處理器才能執行 Gemma 模型。標準 Colab CPU Python 執行階段和 T4 GPU Python 執行階段,足以執行 Gemma 1B、2B 和 4B 大小的模型。如要瞭解其他 GPU 或 TPU 的進階用途,請參閱 Gemma PyTorch 存放區中的 README

在 Kaggle 上使用 Gemma

如要完成本教學課程,您必須先按照 Gemma 設定中的說明操作,瞭解如何執行下列操作:

  • 前往 kaggle.com 取得 Gemma 存取權。
  • 選取有足夠資源可執行 Gemma 模型的 Colab 執行階段。
  • 產生及設定 Kaggle 使用者名稱和 API 金鑰。

完成 Gemma 設定後,請繼續閱讀下一節,瞭解如何設定 Colab 環境的環境變數。

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。當系統提示「要授予存取權嗎?」時,請同意提供密鑰存取權。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安裝依附元件

pip install -q -U torch immutabledict sentencepiece

下載模型權重

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '4b':
  CONFIG = '4b-v1'
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

設定模型的剖析器和檢查點路徑。

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

設定執行環境

以下各節說明如何準備執行 Gemma 的 PyTorch 環境。

準備 PyTorch 執行環境

複製 Gemma Pytorch 存放區,準備 PyTorch 模型執行環境。

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

設定模型設定

執行模型前,您必須設定一些設定參數,包括 Gemma 變數、符號產生器和量化層級。

# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

設定裝置內容

下列程式碼會設定裝置執行模型的內容:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

例項化及載入模型

載入模型及其權重,準備執行要求。

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

執行推論

以下是使用對話模式和多個要求產生內容的範例。

使用特定格式化工具訓練指令調整 Gemma 模型,可在訓練和推論期間,為指令調整範例加上額外資訊的註解。註解 (1) 會指出對話中的角色,並 (2) 標示對話中的輪流發言。

相關的註解符記如下:

  • user:使用者回合
  • model:模型轉向
  • <start_of_turn>:對話輪替的開頭
  • <start_of_image>:圖片資料輸入標記
  • <end_of_turn><eos>:對話輪替結束

如需更多資訊,請參閱 [這篇文章](https://ai.google.dev/gemma/core/prompt-structure),瞭解如何為指令調整的 Gemma 模型設定提示格式。

使用文字生成文字

以下是範例程式碼片段,說明如何在多回合對話中,使用使用者和模型的即時通訊範本,為指令調整的 Gemma 模型設定提示格式。

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

生成含圖片的文字

使用 Gemma 3 以上版本時,您可以將圖片與提示搭配使用。以下範例說明如何在提示中加入視覺資料。

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)

print(model.generate(
    [['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
    device=device,
    output_len=OUTPUT_LEN,
))

瞭解詳情

您現在已瞭解如何在 Pytorch 中使用 Gemma,可以前往 ai.google.dev/gemma 探索 Gemma 的其他用途。請參閱其他相關資源: