使用 PyTorch 运行 Gemma

本指南介绍了如何使用 PyTorch 框架运行 Gemma,包括如何使用图片数据提示 Gemma 版本 3 及更高版本的模型。如需详细了解 Gemma PyTorch 实现,请参阅项目代码库的 README

设置

以下部分介绍了如何设置开发环境,包括如何获取对 Gemma 模型的访问权限以便从 Kaggle 下载、设置身份验证变量、安装依赖项和导入软件包。

系统要求

此 Gemma Pytorch 库需要 GPU 或 TPU 处理器才能运行 Gemma 模型。标准 Colab CPU Python 运行时和 T4 GPU Python 运行时足以运行 Gemma 1B、2B 和 4B 大小的模型。如需了解其他 GPU 或 TPU 的高级用例,请参阅 Gemma PyTorch 代码库中的 README

在 Kaggle 上访问 Gemma

如需完成本教程,您首先需要按照 Gemma 设置中的设置说明操作,其中介绍了如何执行以下操作:

  • kaggle.com 上获取 Gemma 访问权限。
  • 选择一个具有足够资源来运行 Gemma 模型的 Colab 运行时。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,请继续下一部分,在其中为 Colab 环境设置环境变量。

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。当系统提示“要授予访问权限吗?”时,同意提供 Secret 访问权限。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安装依赖项

pip install -q -U torch immutabledict sentencepiece

下载模型权重

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '4b':
  CONFIG = '4b-v1'
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

为模型设置分词器和检查点路径。

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

配置运行环境

以下部分介绍了如何准备 PyTorch 环境以运行 Gemma。

准备 PyTorch 运行环境

克隆 Gemma Pytorch 代码库,准备 PyTorch 模型执行环境。

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

设置模型配置

在运行模型之前,您必须设置一些配置参数,包括 Gemma 变体、分词器和量化级别。

# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

配置设备情境

以下代码会配置设备上下文以运行模型:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

实例化并加载模型

加载模型及其权重,以准备运行请求。

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

运行推断

以下是聊天模式下生成和使用多个请求生成的示例。

指令调优型 Gemma 模型是使用特定的格式设置进行训练的,该格式设置会在训练和推理期间为指令调优示例添加额外信息。注释 (1) 表示对话中的角色,(2) 描述对话中的对话回合。

相关的注解令牌如下所示:

  • user:用户回合
  • model:模型转弯
  • <start_of_turn>:对话转换的开始
  • <start_of_image>:图片数据输入的标记
  • <end_of_turn><eos>:对话转换结束

如需了解详情,请参阅 [此处](https://ai.google.dev/gemma/core/prompt-structure),了解针对指令调优的 Gemma 模型的提示格式

使用文本生成文本

以下示例代码段演示了如何在多轮对话中使用用户和模型聊天模板为基于指令调整的 Gemma 模型设置提示格式。

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

使用图片生成文本

在 Gemma 版本 3 及更高版本中,您可以将图片与问题一起使用。以下示例展示了如何在问题中添加视觉数据。

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)

print(model.generate(
    [['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
    device=device,
    output_len=OUTPUT_LEN,
))

了解详情

现在,您已经了解了如何在 Pytorch 中使用 Gemma,接下来可以前往 ai.google.dev/gemma 探索 Gemma 的众多其他用途。另请参阅以下其他相关资源: