在 ai.google.dev 上查看 | 在 Google Colab 中运行 | 在 Vertex AI 中打开 | 查看 GitHub 上的源代码 |
概览
Gemma 是以 Google DeepMind Gemini 研究和技术为基础打造的轻量级、先进的开放大语言模型系列。本教程演示了如何使用 Google DeepMind 的 gemma
库(使用 JAX(高性能数值计算库)、Flax(基于 JAX 的神经网络库)、Orbax(一种基于 JAX 的库 - 用于训练 Sent0Pieizer 令牌等的 JAX 库)和SentencePiece虽然此笔记本中未直接使用 Flax,但用于创建 Gemma 的 Flax。
此笔记本可以在 Google Colab 上运行免费的 T4 GPU(依次转到修改 > 笔记本设置 > 在硬件加速器下,选择 T4 GPU)。
初始设置
1. 为 Gemma 设置 Kaggle 访问权限
要完成本教程,您首先需要按照 Gemma 设置中的设置说明进行操作,其中说明了如何执行以下操作:
- 通过 kaggle.com 访问 Gemma。
- 选择具有足够资源的 Colab 运行时来运行 Gemma 模型。
- 生成并配置 Kaggle 用户名和 API 密钥。
完成 Gemma 设置后,请继续下一部分,您将为 Colab 环境设置环境变量。
2. 设置环境变量
为 KAGGLE_USERNAME
和 KAGGLE_KEY
设置环境变量。当系统提示“授予访问权限吗?”消息时,同意提供密钥访问权限。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. 安装 gemma
库
此笔记本侧重于使用免费的 Colab GPU。如需启用硬件加速,请依次点击修改 > 笔记本设置 > 选择 T4 GPU > 保存。
接下来,您需要从 github.com/google-deepmind/gemma
安装 Google DeepMind gemma
库。如果收到有关“pip 的依赖项解析器”的错误,通常可以忽略。
pip install -q git+https://github.com/google-deepmind/gemma.git
加载和准备 Gemma 模型
- 使用
kagglehub.model_download
加载 Gemma 模型,该模型采用三个参数:
handle
:Kaggle 中的模型句柄path
:(可选字符串)本地路径force_download
:(可选布尔值)强制重新下载模型
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download... 100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
- 检查模型权重和标记生成器的位置,然后设置路径变量。标记生成器目录位于您下载模型的主目录中,而模型权重将位于一个子目录中。例如:
tokenizer.model
文件将位于/LOCAL/PATH/TO/gemma/flax/2b-it/2
中。- 模型检查点位于
/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it
中)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
执行采样/推断
- 使用
gemma.params.load_and_format_params
方法加载 Gemma 模型检查点并设置其格式:
from gemma import params as params_lib
params = params_lib.load_and_format_params(CKPT_PATH)
- 加载使用
sentencepiece.SentencePieceProcessor
构建的 Gemma 分词器:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- 如需从 Gemma 模型检查点自动加载正确的配置,请使用
gemma.transformer.TransformerConfig
。cache_size
参数是 GemmaTransformer
缓存中时间步数。之后,使用gemma.transformer.Transformer
(继承自flax.linen.Module
)将 Gemma 模型实例化为transformer
。
from gemma import transformer as transformer_lib
transformer_config = transformer_lib.TransformerConfig.from_params(
params=params,
cache_size=1024
)
transformer = transformer_lib.Transformer(transformer_config)
- 在 Gemma 模型检查点/权重和标记生成器中使用
gemma.sampler.Sampler
创建一个sampler
:
from gemma import sampler as sampler_lib
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer'],
)
- 在
input_batch
中编写提示并执行推断。您可以调整total_generation_steps
(生成响应时执行的步骤数;此示例使用100
保留主机内存)。
prompt = [
"\n# What is the meaning of life?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=100,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt: # What is the meaning of life? Output: The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question. **Some common perspectives on the meaning of life include:** * **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce. * **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
- (可选)如果您已完成笔记本运行并想尝试其他提示,请运行此单元以释放内存。之后,您可以在第 3 步中再次实例化
sampler
,并在第 4 步中自定义并运行提示。
del sampler
了解详情
- 您可以详细了解 GitHub 上的 Google DeepMind 库,其中包含您在本教程中使用的模块的文档字符串,例如
gemma.params
、gemma.transformer
和gemma.sampler
。gemma
- 以下库都有自己的文档网站:core JAX、Flax 和 Orbax。
- 如需
sentencepiece
分词器/去标记生成器文档,请参阅 Google 的sentencepiece
GitHub 代码库。 - 如需获取
kagglehub
文档,请查看 Kaggle 的kagglehub
GitHub 代码库中的README.md
。 - 了解如何将 Gemma 模型与 Google Cloud Vertex AI 搭配使用。