通过 KerasNLP 开始使用 Gemma

前往 ai.google.dev 查看 在 Google Colab 中运行 在 Vertex AI 中打开 在 GitHub 上查看源代码

本教程介绍了如何使用 KerasNLP 开始使用 Gemma。Gemma 是一系列先进的轻量级开放模型,采用与 Gemini 模型相同的研究和技术构建而成。KerasNLP 是在 Keras 中实现的一系列自然语言处理 (NLP) 模型,可在 JAX、PyTorch 和 TensorFlow 上运行。

在本教程中,您将使用 Gemma 针对若干提示生成文本回复。如果您刚开始接触 Keras,则可能需要在开始之前阅读 Keras 使用入门,但您不必这样做。在学习本教程时,您将详细了解 Keras。

设置

Gemma 设置

要完成本教程,首先需要按照 Gemma 设置中的说明完成设置。Gemma 设置说明介绍了如何执行以下操作:

  • 在 kaggle.com 上访问 Gemma。
  • 选择具有足够资源的 Colab 运行时来运行 Gemma 2B 模型。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,请继续执行下一部分,您将为 Colab 环境设置环境变量。

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安装依赖项

安装 Keras 和 KerasNLP。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

选择一个后端

Keras 是一个高级的多框架深度学习 API,旨在实现简洁易用。Keras 3 支持选择后端:TensorFlow、JAX 或 PyTorch。这三个选项都适用于本教程。

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

导入软件包

导入 Keras 和 KerasNLP。

import keras
import keras_nlp

创建模型

KerasNLP 提供了许多热门模型架构的实现。在本教程中,您将使用 GemmaCausalLM 创建一个模型,这是一个用于因果语言建模的端到端 Gemma 模型。因果语言模型会根据上一个词元预测下一个词元。

使用 from_preset 方法创建模型:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

GemmaCausalLM.from_preset() 函数会根据预设架构和权重对模型进行实例化。在上面的代码中,字符串 "gemma2_2b_en" 指定了预设的 Gemma 2 2B 模型,其中包含 20 亿个参数。还可以使用具有 7B、9B 和 27B 参数的 Gemma 模型。您可以在 Kaggle 上的模型变体详情中找到 Gemma 模型的代码字符串。

使用 summary 获取有关模型的更多信息:

gemma_lm.summary()

从摘要中可以看出,该模型有 26 亿个可训练参数。

生成文本

现在,可以生成一些文本了!模型具有 generate 方法,用于根据提示生成文本。可选的 max_length 参数指定生成的序列的最大长度。

请尝试使用 "what is keras in 3 bullet points?" 提示符。

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

请使用其他提示再次尝试调用 generate

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

如果您在 JAX 或 TensorFlow 后端上运行,您会注意到第二个 generate 调用几乎立即返回。这是因为,针对给定批次大小和 max_length 的每次调用都是使用 XLA 进行编译的。generate首次运行成本高昂,但后续运行速度要快得多。

您还可以使用列表作为输入提供批量提示:

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

可选:尝试使用其他采样器

您可以通过在 compile() 上设置 sampler 参数来控制 GemmaCausalLM 的生成策略。默认情况下,将使用 "greedy" 抽样。

请尝试设置 "top_k" 策略:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

默认的贪心算法始终会选择概率最大的词元,而 Top-K 算法则会从概率最高的 K 词元中随机选择下一个词元。

您无需指定采样器,并且如果最后一个代码段对您的用例没有帮助,则可以忽略它。如需详细了解可用的采样器,请参阅采样器

后续步骤

在本教程中,您学习了如何使用 KerasNLP 和 Gemma 生成文本。下面提供了一些有关后续学习内容的建议: