前往 ai.google.dev 查看 | 在 Google Colab 中运行 | 在 Vertex AI 中打开 | 在 GitHub 上查看源代码 |
本教程演示了如何使用 RecurrentGemma 2B Instruct 模型执行基本采样/推理,即使用 Google DeepMind 的 recurrentgemma
库,该库使用 JAX(一个高性能数值计算库)、Flax(基于 JAX 的神经网络库)、Orbax(一个基于 JAX Pieceizer 的训练库)和 Checkpoint2 等实用程序{detoken1}等工具编写而成。SentencePiece虽然 Flax 未直接用于此笔记本,但使用了 Flax 来创建 Gemma 和 RecurrentGemma(Griffin 模型)。
此笔记本可以在采用 T4 GPU 的 Google Colab 上运行(依次前往修改 > 笔记本设置 > 在硬件加速器下选择 T4 GPU)。
设置
以下部分介绍了准备笔记本以使用 RecurrentGemma 模型的步骤,包括模型访问、获取 API 密钥和配置笔记本运行时
为 Gemma 设置 Kaggle 访问权限
如需完成本教程,您首先需要按照类似于 Gemma 设置的设置说明进行操作,但有一些例外情况:
- 在 kaggle.com 上访问 RecurrentGemma(而非 Gemma)。
- 请选择具有足够资源的 Colab 运行时来运行 RecurrentGemma 模型。
- 生成并配置 Kaggle 用户名和 API 密钥。
完成 RecurrentGemma 的设置后,请继续执行下一部分,您将为 Colab 环境设置环境变量。
设置环境变量
为 KAGGLE_USERNAME
和 KAGGLE_KEY
设置环境变量。当看到“是否授予访问权限?”的提示时,则同意提供私密访问。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
安装 recurrentgemma
库
此笔记本侧重于如何使用免费的 Colab GPU。要启用硬件加速,请依次点击修改 >笔记本设置 >选择 T4 GPU >保存。
接下来,您需要从 github.com/google-deepmind/recurrentgemma
安装 Google DeepMind recurrentgemma
库。如果您收到有关“pip 的依赖项解析器”的错误,通常可以忽略它。
pip install git+https://github.com/google-deepmind/recurrentgemma.git
加载并准备 RecurrentGemma 模型
- 使用
kagglehub.model_download
加载 RecurrentGemma 模型,该方法接受三个参数:
handle
:Kaggle 的模型句柄path
:(可选字符串)本地路径force_download
:(可选布尔值)强制重新下载模型
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- 检查模型权重和标记生成器的位置,然后设置路径变量。标记生成器目录位于下载模型的主目录中,而模型权重则位于子目录中。例如:
tokenizer.model
文件将位于/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
中。- 模型检查点将位于
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
中。
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
执行采样/推理
- 使用
recurrentgemma.jax.load_parameters
方法加载 RecurrentGemma 模型检查点。设置为"single_device"
的sharding
参数会在单个设备上加载所有模型参数。
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- 加载使用
sentencepiece.SentencePieceProcessor
构建的 RecurrentGemma 模型分词器:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- 如需从 RecurrentGemma 模型检查点自动加载正确配置,请使用
recurrentgemma.GriffinConfig.from_flax_params_or_variables
。然后,使用recurrentgemma.jax.Griffin
实例化 Griffin 模型。
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- 在 RecurrentGemma 模型检查点/权重和标记生成器之上,使用
recurrentgemma.jax.Sampler
创建sampler
:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- 在
prompt
中编写提示并执行推理。您可以调整total_generation_steps
(生成响应时执行的步骤数 - 此示例使用50
保留主机内存)。
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
了解详情
- 您可以在 GitHub 上详细了解 Google DeepMind 的
recurrentgemma
库,该库包含您在本教程中使用的方法和模块的文档字符串,例如recurrentgemma.jax.load_parameters
、recurrentgemma.jax.Griffin
和recurrentgemma.jax.Sampler
。 - 以下库有自己的文档网站:core JAX、Flax 和 Orbax。
- 如需查看
sentencepiece
标记生成器/detokenizer 文档,请查看 Google 的sentencepiece
GitHub 代码库。 - 如需查看
kagglehub
文档,请参阅 Kaggle 的kagglehub
GitHub 代码库中的README.md
。 - 了解如何将 Gemma 模型与 Google Cloud Vertex AI 搭配使用。
- 查看 RecurrentGemma: Moving Past Transformer 的《Efficient Open Language Models》论文。
- 阅读 Griffin: Mixing Gated Linear Recurrences with “Local Attention for Efficient Language Models”这篇文章,详细了解 RecurrentGemma 使用的模型架构。