在 ai.google.dev 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 |
以下是关于在 PyTorch 中运行 Gemma 推理的快速演示。 如需了解详情,请点击此处查看官方 PyTorch 实现的 GitHub 代码库。
请注意:
- 免费的 Colab CPU Python 运行时和 T4 GPU Python 运行时足以运行 Gemma 2B 模型和 70 亿的 int8 量化模型。
- 如需了解其他 GPU 或 TPU 的高级用例,请参阅官方代码库中的 README.md。
1. 为 Gemma 设置 Kaggle 访问权限
要完成本教程,您首先需要按照 Gemma 设置中的设置说明进行操作,了解如何执行以下操作:
- 在 kaggle.com 上获取 Gemma 访问权限。
- 选择具有足够资源来运行 Gemma 模型的 Colab 运行时。
- 生成并配置 Kaggle 用户名和 API 密钥。
完成 Gemma 设置后,请继续下一部分,在其中为 Colab 环境设置环境变量。
2. 设置环境变量
为 KAGGLE_USERNAME
和 KAGGLE_KEY
设置环境变量。当系统显示“要授予访问权限吗?”消息时,请同意提供 Secret 访问权限。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
安装依赖项
pip install -q -U torch immutabledict sentencepiece
下载模型权重
# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '2b':
CONFIG = '2b-v2'
import os
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
下载模型实现
# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
设置模型
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT
# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
运行推断
以下示例展示了如何在聊天模式下生成内容以及如何使用多个请求生成内容。
指令调优型 Gemma 模型是使用特定的格式设置进行训练的,该格式设置会在训练和推理期间为指令调优示例添加额外信息注释。注释 (1) 表示对话中的角色,(2) 描述对话中的对话回合。
相关的注释令牌如下所示:
user
:用户回合model
:模型转弯<start_of_turn>
:对话转换的开头<end_of_turn><eos>
:对话结束
如需了解详情,请点击此处,了解针对指令调优的 Gemma 模型的提示格式设置。
以下代码段示例演示了如何在多轮对话中使用用户和模型聊天模板为指令调优的 Gemma 模型设置提示格式。
# Generate with one request in chat mode
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=128,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
了解详情
现在,您已经了解了如何在 Pytorch 中使用 Gemma,接下来可以前往 ai.google.dev/gemma 探索 Gemma 的众多其他用途。另请参阅以下其他相关资源: