使用 JAX 和 Flax 使用 Gemma 进行推断

在 ai.google.dev 上查看 在 Google Colab 中运行 在 Vertex AI 中打开 查看 GitHub 上的源代码

概览

Gemma 是以 Google DeepMind Gemini 研究和技术为基础打造的轻量级、先进的开放大语言模型系列。本教程演示了如何使用 Google DeepMind 的 gemma(使用 JAX(高性能数值计算库)、Flax(基于 JAX 的神经网络库)、Orbax(一种基于 JAX 的库 - 用于训练 Sent0Pieizer 令牌等的 JAX 库)和SentencePiece虽然此笔记本中未直接使用 Flax,但用于创建 Gemma 的 Flax。

此笔记本可以在 Google Colab 上运行免费的 T4 GPU(依次转到修改 > 笔记本设置 > 在硬件加速器下,选择 T4 GPU)。

初始设置

1. 为 Gemma 设置 Kaggle 访问权限

要完成本教程,您首先需要按照 Gemma 设置中的设置说明进行操作,其中说明了如何执行以下操作:

  • 通过 kaggle.com 访问 Gemma。
  • 选择具有足够资源的 Colab 运行时来运行 Gemma 模型。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,请继续下一部分,您将为 Colab 环境设置环境变量。

2. 设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。当系统提示“授予访问权限吗?”消息时,同意提供密钥访问权限。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. 安装 gemma

此笔记本侧重于使用免费的 Colab GPU。如需启用硬件加速,请依次点击修改 > 笔记本设置 > 选择 T4 GPU > 保存

接下来,您需要从 github.com/google-deepmind/gemma 安装 Google DeepMind gemma 库。如果收到有关“pip 的依赖项解析器”的错误,通常可以忽略。

pip install -q git+https://github.com/google-deepmind/gemma.git

加载和准备 Gemma 模型

  1. 使用 kagglehub.model_download 加载 Gemma 模型,该模型采用三个参数:
  • handle:Kaggle 中的模型句柄
  • path:(可选字符串)本地路径
  • force_download:(可选布尔值)强制重新下载模型
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. 检查模型权重和标记生成器的位置,然后设置路径变量。标记生成器目录位于您下载模型的主目录中,而模型权重将位于一个子目录中。例如:
  • tokenizer.model 文件将位于 /LOCAL/PATH/TO/gemma/flax/2b-it/2 中。
  • 模型检查点位于 /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it 中)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

执行采样/推断

  1. 使用 gemma.params.load_and_format_params 方法加载 Gemma 模型检查点并设置其格式:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. 加载使用 sentencepiece.SentencePieceProcessor 构建的 Gemma 分词器:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. 如需从 Gemma 模型检查点自动加载正确的配置,请使用 gemma.transformer.TransformerConfigcache_size 参数是 Gemma Transformer 缓存中时间步数。之后,使用 gemma.transformer.Transformer(继承自 flax.linen.Module)将 Gemma 模型实例化为 transformer
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. 在 Gemma 模型检查点/权重和标记生成器中使用 gemma.sampler.Sampler 创建一个 sampler
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. input_batch 中编写提示并执行推断。您可以调整 total_generation_steps(生成响应时执行的步骤数;此示例使用 100 保留主机内存)。
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (可选)如果您已完成笔记本运行并想尝试其他提示,请运行此单元以释放内存。之后,您可以在第 3 步中再次实例化 sampler,并在第 4 步中自定义并运行提示。
del sampler

了解详情