Get started with Gemma using KerasNLP


This tutorial shows you how to get started with Gemma using KerasNLP. Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.

In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read Getting started with Keras before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial.

Setup

Gemma setup

To complete this tutorial, you'll first need to complete the setup instructions at Gemma setup. The Gemma setup instructions show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma 2B model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Install dependencies

Install Keras and KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"  # Quoted so the shell doesn't treat ">" as a redirect.

Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Keras 3 lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Import packages

Import Keras and KerasNLP.

import keras
import keras_nlp

Create a model

KerasNLP provides implementations of many popular model architectures. In this tutorial, you'll create a model using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.

Create the model using the from_preset method:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

from_preset instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture: a Gemma model with 2 billion parameters.
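
If you're curious which other presets are available, you can inspect the presets attribute on the class. This is a quick sketch, not part of the original tutorial; the exact set of preset names depends on your installed KerasNLP version.

# List the preset names bundled with this KerasNLP release. The available
# Gemma variants depend on the installed version of KerasNLP.
print(keras_nlp.models.GemmaCausalLM.presets.keys())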

Use summary to get more info about the model:

gemma_lm.summary()

As you can see from the summary, the model has 2.5 billion trainable parameters.
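
If you'd like to verify that number programmatically, you can use Keras's standard count_params method on the underlying backbone. This is a quick sanity check, not part of the original tutorial:

# Total parameter count of the underlying Gemma backbone. For the
# "gemma_2b_en" preset this is roughly 2.5 billion, matching the summary.
print(gemma_lm.backbone.count_params())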

Generate text

Now it's time to generate some text! The model has a generate method that generates text based on a prompt. The optional max_length argument specifies the maximum length of the generated sequence.

Try it out with the prompt "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Try calling generate again with a different prompt.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

If you're running on the JAX or TensorFlow backend, you'll notice that the second generate call returns almost immediately. This is because each call to generate for a given batch size and max_length is compiled with XLA. The first run is expensive, but subsequent runs are much faster.
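
You can observe this yourself by timing two identical calls. The sketch below uses Python's time module; the absolute numbers depend on your hardware, but the second call should be dramatically faster.

import time

# First call: triggers XLA compilation for this batch size and max_length,
# so it includes the one-time compilation overhead.
start = time.time()
gemma_lm.generate("How does the brain work?", max_length=64)
print(f"First call:  {time.time() - start:.2f}s")

# Second call with the same shapes: reuses the compiled program.
start = time.time()
gemma_lm.generate("How does the brain work?", max_length=64)
print(f"Second call: {time.time() - start:.2f}s")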

You can also provide batched prompts by using a list as input:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Optional: Try a different sampler

You can control the generation strategy for GemmaCausalLM by setting the sampler argument on compile(). By default, "greedy" sampling is used.

As an experiment, try setting a "top_k" strategy:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

While the default greedy algorithm always picks the token with the highest probability, the top-K algorithm randomly picks the next token from the K tokens with the highest probability.
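
Instead of the string shortcut, you can also pass a sampler instance to compile(), which lets you set parameters such as K. A minimal sketch, assuming the keras_nlp.samplers.TopKSampler API (see the Samplers docs for details):

# Sample from the 5 most likely tokens at each step. Fixing a seed makes
# the otherwise random output reproducible.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
gemma_lm.generate("What is the meaning of life?", max_length=64)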

You don't have to specify a sampler, and you can ignore the last code snippet if it isn't helpful for your use case. If you'd like to learn more about the available samplers, see Samplers.

What's next

In this tutorial, you learned how to generate text using KerasNLP and Gemma. Here are a few suggestions for what to learn next: