透過 KerasNLP 開始使用 Gemma

前往 ai.google.dev 查看 在 Google Colab 中執行 在 Vertex AI 中開啟 在 GitHub 上查看原始碼

本教學課程說明如何透過 KerasNLP 開始使用 Gemma。Gemma 是一系列先進的輕量開放模型,採用與建立 Gemini 模型相同的研究和技術打造而成。KerasNLP 是一組在 Keras 中實作的自然語言處理 (NLP) 模型,可在 JAX、PyTorch 和 TensorFlow 上執行。

在這個教學課程中,您將使用 Gemma 產生多則提示的文字回應。如果您是 Keras 新手,建議您先閱讀「開始使用 Keras」一文,然後再進行後續步驟。逐步進行本教學課程時,將會進一步瞭解 Keras。

設定

Gemma 設定

如要完成本教學課程,您必須先在 Gemma 設定中完成設定。Gemma 設定操作說明會說明如何執行下列操作:

  • 在 kaggle.com 上存取 Gemma。
  • 請選取具備足夠資源的 Colab 執行階段,才能執行 Gemma 2B 模型。
  • 產生並設定 Kaggle 使用者名稱與 API 金鑰。

完成 Gemma 設定後,請前往下一節,為 Colab 環境設定環境變數。

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安裝依附元件

安裝 Keras 和 KerasNLP。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

選取後端

Keras 是一種高階的多架構深度學習 API,專為簡單易用而設計,Keras 3 可讓您選擇後端:TensorFlow、JAX 或 PyTorch。這三項指標都適用於本教學課程。

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

匯入套件

匯入 Keras 和 KerasNLP。

import keras
import keras_nlp

建立模型

KerasNLP 提供許多熱門模型架構的實作。在這個教學課程中,您會使用 GemmaCausalLM 建立模型,這是用於因果語言模型的端對端 Gemma 模型。因果語言模型會根據先前的符記預測下一個符記。

使用 from_preset 方法建立模型:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

from_preset 會從預設架構和權重將模型例項化。上述程式碼中的字串 "gemma_2b_en" 會指定預設架構:包含 20 億個參數的 Gemma 模型。

使用 summary 取得模型的詳細資訊:

gemma_lm.summary()

如摘要所示,模型有 25 億個可訓練參數。

生成文字

現在可以開始生成文字了!模型有 generate 方法,可根據提示產生文字。選用的 max_length 引數會指定所產生序列的長度上限。

試試看搭配 "What is the meaning of life?" 提示即可試用。

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

請改用其他提示再次呼叫 generate

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

如果您在 JAX 或 TensorFlow 後端上執行,則第二個 generate 呼叫幾乎會立即傳回。這是因為特定批量大小每次呼叫 generatemax_length 都會使用 XLA 編譯。第一次執行的費用高昂,但後續執行作業的執行速度會更快。

您也可以使用清單做為輸入方式,提供批次提示:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

選用:改用其他取樣器

您可在 compile() 上設定 sampler 引數,控制 GemmaCausalLM 的產生策略。根據預設,系統將使用 "greedy" 取樣。

對於實驗,您可以嘗試設定 "top_k" 策略:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

雖然預設的貪婪演算法一律會選擇機率最高的符記,「Top-K」演算法會隨機從機率最高的 K 符記中挑選下一個符記。

您無須指定取樣器,如果最後的程式碼片段對您的用途沒有幫助,可以忽略最後一個程式碼片段。如要進一步瞭解可用的取樣器,請參閱取樣器

後續步驟

在這個教學課程中,您已學會如何使用 KerasNLP 和 Gemma 產生文字。歡迎參考下列建議,瞭解後續發展: