使用 Keras 使用 Gemma 進行推論

前往 ai.google.dev 查看 在 Google Colab 中執行 在 Vertex AI 中開啟 前往 GitHub 查看原始碼

本教學課程說明如何將 Gemma 與 KerasNLP 搭配使用,執行推論及產生文字。Gemma 是一系列先進的開放式模型,與建立 Gemini 模型時使用的研究和技術相同。KerasNLP 是一組在 Keras 中實作的自然語言處理 (NLP) 模型,可在 JAX、PyTorch 和 TensorFlow 上執行。

在這個教學課程中,您將使用 Gemma 針對多個提示產生文字回應。如果您是 Keras 新手,建議在開始之前先參閱 開始使用 Keras,但這並非必要。在進行本教學課程的過程中,您將進一步瞭解 Keras。

設定

Gemma 設定

如要完成本教學課程,您必須先前往 Gemma 設定頁面完成設定。Gemma 設定操作說明會說明如何執行下列操作:

  • 前往 kaggle.com 存取 Gemma。
  • 請選取具有足夠資源來執行 Gemma 2B 模型的 Colab 執行階段。
  • 產生並設定 Kaggle 使用者名稱和 API 金鑰。

完成 Gemma 設定後,請繼續前往下一節,設定 Colab 環境的環境變數。

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安裝依附元件

安裝 Keras 和 KerasNLP。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

選取後端

Keras 是高階的多架構深度學習 API,專為簡化使用而設計。Keras 3 可讓您選擇後端:TensorFlow、JAX 或 PyTorch。這三者都能進行本教學課程。

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

匯入套件

匯入 Keras 和 KerasNLP。

import keras
import keras_nlp

建立模型

KerasNLP 可實作許多熱門的模型架構。在這個教學課程中,您會使用 GemmaCausalLM 建立模型,這是用於因果語言模型的端對端 Gemma 模型。因果語言模型會根據先前的符記預測下一個符記。

使用 from_preset 方法建立模型:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

GemmaCausalLM.from_preset() 函式會根據預設架構和權重將模型例項化。在上述程式碼中,字串 "gemma_2b_en" 指定具有 20 億個參數的預設 Gemma 2B 模型。您也可以使用具有 7B、9B 和 27B 參數的 Gemma 模型。您可以在 kaggle.com 的「模型變化版本」清單中,找到 Gemma 模型的程式碼字串。

使用 summary 取得模型的詳細資訊:

gemma_lm.summary()

如摘要所示,模型有 25 億個可訓練參數。

生成文字

現在該生成一些文字了!這個模型採用 generate 方法,可根據提示生成文字。選用的 max_length 引數會指定產生的序列長度上限。

輸入 "What is the meaning of life?" 提示試試看吧。

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

請改用其他提示再次呼叫 generate

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

如果是 JAX 或 TensorFlow 後端執行,會發現第二次 generate 呼叫幾乎會立即傳回結果。這是因為每次呼叫 generate 針對特定批量,max_length 則是使用 XLA 編譯。首次執行的費用較高,但後續執行作業會加快許多。

您也可以使用清單做為輸入內容,提供批次提示:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

選擇性:改用其他取樣器

您可以藉由設定 compile() 上的 sampler 引數,控制 GemmaCausalLM 的產生策略。根據預設,系統會使用 "greedy" 的取樣功能。

您可以嘗試設定 "top_k" 策略:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

預設的貪婪演算法一律會挑選機率最高的符記,「Top-K」演算法會從「前 K 高」機率的符記中隨機挑選下一個符記。

您不必指定取樣器,如果最後的程式碼片段對您的用途沒有幫助,可以忽略最後一個程式碼片段。如要進一步瞭解可用的取樣器,請參閱取樣器

後續步驟

在本教學課程中,您已瞭解如何使用 KerasNLP 和 Gemma 來產生文字。您也可以參考下列建議,瞭解接下來該怎麼做: