使用 JAX 和 Flax 通过 RecurrentGemma 进行推断

前往 ai.google.dev 查看 在 Google Colab 中运行 在 Vertex AI 中打开 在 GitHub 上查看源代码

本教程演示了如何使用 RecurrentGemma 2B Instruct 模型执行基本采样/推理,即使用 Google DeepMind 的 recurrentgemma,该库使用 JAX(一个高性能数值计算库)、Flax(基于 JAX 的神经网络库)、Orbax(一个基于 JAX Pieceizer 的训练库)和 Checkpoint2 等实用程序{detoken1}等工具编写而成。SentencePiece虽然 Flax 未直接用于此笔记本,但使用了 Flax 来创建 Gemma 和 RecurrentGemma(Griffin 模型)。

此笔记本可以在采用 T4 GPU 的 Google Colab 上运行(依次前往修改 > 笔记本设置 > 在硬件加速器下选择 T4 GPU)。

设置

以下部分介绍了准备笔记本以使用 RecurrentGemma 模型的步骤,包括模型访问、获取 API 密钥和配置笔记本运行时

为 Gemma 设置 Kaggle 访问权限

如需完成本教程,您首先需要按照类似于 Gemma 设置的设置说明进行操作,但有一些例外情况:

  • kaggle.com 上访问 RecurrentGemma(而非 Gemma)。
  • 请选择具有足够资源的 Colab 运行时来运行 RecurrentGemma 模型。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 RecurrentGemma 的设置后,请继续执行下一部分,您将为 Colab 环境设置环境变量。

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。当看到“是否授予访问权限?”的提示时,则同意提供私密访问。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安装 recurrentgemma

此笔记本侧重于如何使用免费的 Colab GPU。要启用硬件加速,请依次点击修改 >笔记本设置 >选择 T4 GPU >保存

接下来,您需要从 github.com/google-deepmind/recurrentgemma 安装 Google DeepMind recurrentgemma 库。如果您收到有关“pip 的依赖项解析器”的错误,通常可以忽略它。

pip install git+https://github.com/google-deepmind/recurrentgemma.git

加载并准备 RecurrentGemma 模型

  1. 使用 kagglehub.model_download 加载 RecurrentGemma 模型,该方法接受三个参数:
  • handle:Kaggle 的模型句柄
  • path:(可选字符串)本地路径
  • force_download:(可选布尔值)强制重新下载模型
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. 检查模型权重和标记生成器的位置,然后设置路径变量。标记生成器目录位于下载模型的主目录中,而模型权重则位于子目录中。例如:
  • tokenizer.model 文件将位于 /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 中。
  • 模型检查点将位于 /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it 中。
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

执行采样/推理

  1. 使用 recurrentgemma.jax.load_parameters 方法加载 RecurrentGemma 模型检查点。设置为 "single_device"sharding 参数会在单个设备上加载所有模型参数。
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. 加载使用 sentencepiece.SentencePieceProcessor 构建的 RecurrentGemma 模型分词器:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. 如需从 RecurrentGemma 模型检查点自动加载正确配置,请使用 recurrentgemma.GriffinConfig.from_flax_params_or_variables。然后,使用 recurrentgemma.jax.Griffin 实例化 Griffin 模型。
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. 在 RecurrentGemma 模型检查点/权重和标记生成器之上,使用 recurrentgemma.jax.Sampler 创建 sampler
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt 中编写提示并执行推理。您可以调整 total_generation_steps(生成响应时执行的步骤数 - 此示例使用 50 保留主机内存)。
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

了解详情