本教學課程說明如何使用 KerasNLP 開始使用 Gemma。Gemma 是一系列先進的輕量級開放式模型,採用與建立 Gemini 模型時相同的研究成果和技術。KerasNLP 是一系列以 Keras 實作的自然語言處理 (NLP) 模型,可在 JAX、PyTorch 和 TensorFlow 上執行。
在本教學課程中,您將使用 Gemma 針對幾個提示產生文字回覆。如果您是 Keras 新手,建議您先閱讀「開始使用 Keras」一文,但這並非必要。您將在本教學課程中進一步瞭解 Keras。
設定
Gemma 設定
如要完成本教學課程,您必須先完成 Gemma 設定的設定說明。Gemma 設定說明會顯示如何執行下列操作:
- 前往 kaggle.com 取得 Gemma 的存取權。
- 選取具有足夠資源的 Colab 執行階段,以便執行 Gemma 2B 模型。
- 產生及設定 Kaggle 使用者名稱和 API 金鑰。
完成 Gemma 設定後,請繼續閱讀下一節,瞭解如何設定 Colab 環境的環境變數。
設定環境變數
設定 KAGGLE_USERNAME
和 KAGGLE_KEY
的環境變數。
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
安裝依附元件
安裝 Keras 和 KerasNLP。
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"
選取後端
Keras 是高階的多架構深度學習 API,旨在提供簡單易用的服務。Keras 3 可讓您選擇後端:TensorFlow、JAX 或 PyTorch。這三種方法皆適用於本教學課程。
import os
os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
匯入套件
匯入 Keras 和 KerasNLP。
import keras
import keras_nlp
建立模型
KerasNLP 提供許多熱門模型架構的實作方式。在本教學課程中,您將使用 GemmaCausalLM
建立模型,這是用於因果語言建模的端對端 Gemma 模型。因果語言模型會根據先前的符記預測下一個符記。
使用 from_preset
方法建立模型:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
GemmaCausalLM.from_preset()
函式會根據預先設定的架構和權重,將模型例項化。在上方程式碼中,字串 "gemma2_2b_en"
會指定 Gemma 2 2B 模型的預設值,其中包含 20 億個參數。您也可以使用含有 7B、9B 和 27B 參數的 Gemma 模型。您可以在 Kaggle 的模型變化版本清單中,找到 Gemma 模型的程式碼字串。
使用 summary
可進一步瞭解模型:
gemma_lm.summary()
如摘要所示,模型有 26 億個可訓練參數。
生成文字
接下來,我們要產生一些文字!這個模型具有 generate
方法,可根據提示產生文字。選用的 max_length
引數會指定產生序列的最大長度。
請使用提示 "what is keras in 3 bullet points?"
試試看。
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'
請使用其他提示再試一次。generate
gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'
如果您在 JAX 或 TensorFlow 後端上執行,您會發現第二個 generate
呼叫幾乎立即傳回。這是因為每個對 generate
的呼叫 (針對特定批次大小) 和 max_length
都會使用 XLA 編譯。首次執行的成本較高,但後續執行的速度會快得多。
您也可以使用清單做為輸入內容,提供批次提示:
gemma_lm.generate(
["what is keras in 3 bullet points?",
"The universe is"],
max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n', 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']
選用:嘗試使用其他取樣器
您可以設定 compile()
的 sampler
引數,藉此控制 GemmaCausalLM
的產生策略。根據預設,系統會使用 "greedy"
取樣。
您可以嘗試設定 "top_k"
策略做為實驗:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'
雖然預設的貪婪演算法一律會選取機率最高的符記,但前 K 大演算法會從機率最高的符記中隨機選取下一個符記。
您不必指定取樣器,如果最後一個程式碼片段對您的用途沒有幫助,您可以忽略該片段。如要進一步瞭解可用的取樣器,請參閱「取樣器」。
後續步驟
在本教學課程中,您已瞭解如何使用 KerasNLP 和 Gemma 產生文字。以下提供一些建議,供您參考下一個學習主題:
- 瞭解如何微調 Gemma 模型。
- 瞭解如何對 Gemma 模型執行分散式微調和推論。
- 瞭解 Gemma 與 Vertex AI 的整合
- 瞭解如何搭配 Vertex AI 使用 Gemma 模型。