本教程介绍了如何开始使用 KerasNLP 与 Gemma 交互。Gemma 是一系列先进的轻量级开放式模型,其开发采用了与 Gemini 模型相同的研究成果和技术。KerasNLP 是一系列在 Keras 中实现的自然语言处理 (NLP) 模型,可在 JAX、PyTorch 和 TensorFlow 上运行。
在本教程中,您将使用 Gemma 针对几个问题生成文本回答。如果您刚开始接触 Keras,不妨先阅读 Keras 使用入门,但这并非强制性要求。在本教程中,您将详细了解 Keras。
设置
Gemma 设置
如需完成本教程,您首先需要按照 Gemma 设置中的设置说明操作。Gemma 设置说明介绍了如何执行以下操作:
- 在 kaggle.com 上获取 Gemma 的访问权限。
- 选择具有足够资源来运行 Gemma 2B 模型的 Colab 运行时。
- 生成并配置 Kaggle 用户名和 API 密钥。
完成 Gemma 设置后,请继续下一部分,在其中为 Colab 环境设置环境变量。
设置环境变量
为 KAGGLE_USERNAME
和 KAGGLE_KEY
设置环境变量。
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
安装依赖项
安装 Keras 和 KerasNLP。
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"
选择一个后端
Keras 是一个高级多框架深度学习 API,旨在实现简单易用。在 Keras 3 中,您可以选择后端:TensorFlow、JAX 或 PyTorch。在本教程中,这三种方式都可以。
import os
os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
导入软件包
导入 Keras 和 KerasNLP。
import keras
import keras_nlp
创建模型
KerasNLP 提供了许多热门模型架构的实现。在本教程中,您将使用 GemmaCausalLM
(一种用于因果语言建模的端到端 Gemma 模型)创建模型。因果语言模型可根据之前的令牌预测下一个令牌。
使用 from_preset
方法创建模型:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
GemmaCausalLM.from_preset()
函数会根据预设的架构和权重实例化模型。在上面的代码中,字符串 "gemma2_2b_en"
指定了具有 20 亿个参数的预设 Gemma 2 2B 模型。我们还提供70 亿、90 亿和 270 亿个参数的 Gemma 模型。您可以在 Kaggle 上的 Model Variation 列表中找到 Gemma 模型的代码字符串。
使用 summary
获取有关模型的更多信息:
gemma_lm.summary()
如摘要所示,该模型有 26 亿个可训练参数。
生成文本
现在,该生成一些文本了!该模型具有一个 generate
方法,该方法基于提示生成文本。可选的 max_length
参数用于指定生成的序列的长度上限。
您可以使用 "what is keras in 3 bullet points?"
提示试用此功能。
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'
请使用其他提示再次调用 generate
。
gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'
如果您在 JAX 或 TensorFlow 后端上运行,您会发现第二个 generate
调用几乎会立即返回。这是因为,对于给定的批处理大小,对 generate
的每次调用和 max_length
都是使用 XLA 编译的。首次运行的费用较高,但后续运行的速度会快得多。
您还可以使用列表作为输入来提供批量提示:
gemma_lm.generate(
["what is keras in 3 bullet points?",
"The universe is"],
max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n', 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']
可选:尝试使用其他采样器
您可以通过在 compile()
上设置 sampler
参数来控制 GemmaCausalLM
的生成策略。默认情况下,系统会使用 "greedy"
采样。
作为一项实验,请尝试设置 "top_k"
策略:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'
虽然默认的贪心算法始终会选择概率最大的 token,但 top-K 算法会从概率最高的 K 个 token 中随机选择下一个 token。
您无需指定采样器,如果最后一个代码段对您的用例没有帮助,您可以忽略它。如需详细了解可用的抽样器,请参阅抽样器。
后续步骤
在本教程中,您学习了如何使用 KerasNLP 和 Gemma 生成文本。下面是有关后续学习内容的一些建议:
- 了解如何微调 Gemma 模型。
- 了解如何对 Gemma 模型执行分布式微调和推理。
- 了解 Gemma 与 Vertex AI 的集成
- 了解如何将 Gemma 模型与 Vertex AI 搭配使用。