透過 KerasNLP 開始使用 Gemma

前往 ai.google.dev 查看 在 Google Colab 中執行 在 Vertex AI 中開啟 前往 GitHub 查看原始碼

本教學課程說明如何透過 KerasNLP 開始使用 Gemma。Gemma 是一系列先進的開放式模型,與建立 Gemini 模型時使用的研究和技術相同。KerasNLP 是一組在 Keras 中實作的自然語言處理 (NLP) 模型,可在 JAX、PyTorch 和 TensorFlow 上執行。

在這個教學課程中,您將使用 Gemma 針對多個提示產生文字回應。如果您是 Keras 新手,建議在開始之前先參閱 開始使用 Keras,但這並非必要。在進行本教學課程的過程中,您將進一步瞭解 Keras。

設定

Gemma 設定

如要完成本教學課程,您必須先前往 Gemma 設定頁面完成設定。Gemma 設定操作說明會說明如何執行下列操作:

  • 前往 kaggle.com 存取 Gemma。
  • 選取資源充足且可執行的 Colab 執行階段 Gemma 2B 模型
  • 產生並設定 Kaggle 使用者名稱和 API 金鑰。

完成 Gemma 設定後,請繼續前往下一節,設定 Colab 環境的環境變數。

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安裝依附元件

安裝 Keras 和 KerasNLP。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

選取後端

Keras 是高階的多架構深度學習 API,專為簡化使用而設計。Keras 3 可讓您選擇後端:TensorFlow、JAX 或 PyTorch。這三者都能進行本教學課程。

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

匯入套件

匯入 Keras 和 KerasNLP。

import keras
import keras_nlp

建立模型

KerasNLP 可實作許多熱門的模型架構。在這個教學課程中,您會使用 GemmaCausalLM 建立模型,這是用於因果語言模型的端對端 Gemma 模型。因果語言模型會根據先前的符記預測下一個符記。

使用 from_preset 方法建立模型:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

GemmaCausalLM.from_preset() 函式會根據預設架構和權重將模型例項化。在上述程式碼中,字串 "gemma2_2b_en" 指定具有 20 億個參數的 Gemma 2 2B 模型。您也可以使用具有 7B、9B 和 27B 參數的 Gemma 模型。您可以在 Kaggle 的「模型變化版本」清單中,找到 Gemma 模型的程式碼字串。

使用 summary 取得模型的詳細資訊:

gemma_lm.summary()

如摘要所示,模型有 26 億個可訓練參數。

生成文字

現在該生成一些文字了!這個模型採用 generate 方法,可根據提示生成文字。選用的 max_length 引數會指定產生的序列長度上限。

輸入 "what is keras in 3 bullet points?" 提示試試看吧。

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

請改用其他提示再次呼叫 generate

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

如果是 JAX 或 TensorFlow 後端執行,會發現第二次 generate 呼叫幾乎會立即傳回結果。這是因為每次呼叫 generate 針對特定批量,max_length 則是使用 XLA 編譯。首次執行的費用較高,但後續執行作業會加快許多。

您也可以使用清單做為輸入內容,提供批次提示:

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

選擇性:改用其他取樣器

您可以藉由設定 compile() 上的 sampler 引數,控制 GemmaCausalLM 的產生策略。根據預設,系統會使用 "greedy" 的取樣功能。

您可以嘗試設定 "top_k" 策略:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

預設的貪婪演算法一律會挑選機率最高的符記,「Top-K」演算法會從「前 K 高」機率的符記中隨機挑選下一個符記。

您不必指定取樣器,如果最後的程式碼片段對您的用途沒有幫助,可以忽略最後一個程式碼片段。如要進一步瞭解可用的取樣器,請參閱取樣器

後續步驟

在本教學課程中,您已瞭解如何使用 KerasNLP 和 Gemma 來產生文字。您也可以參考下列建議,瞭解接下來該怎麼做: