前往 ai.google.dev 查看 | 在 Google Colab 中執行 | 在 Vertex AI 中開啟 | 前往 GitHub 查看原始碼 |
本教學課程示範如何使用 RecurrentGemma 2B Instruct 模型執行基本的取樣/推論:使用以 JAX (高效能數值運算程式庫)、Flax (例如 JAX 型類神經網路程式庫) 和 Orbax雖然這個筆記本中並未直接使用 Flax,但 Flax 是用來建立 Gemma 和 RecurrentGemma (Griffin 模型) 使用。
這個筆記本可以在 Google Colab 和 T4 GPU 上執行 (依序前往「編輯」>「筆記本設定」>「硬體加速器」下方選取「T4 GPU」)。
設定
以下各節說明如何準備筆記本以使用 RecurrentGemma 模型,包括模型存取、取得 API 金鑰,以及設定筆記本執行階段
設定 Gemma 的 Kaggle 存取權
如要完成本教學課程,請先按照類似 Gemma 設定的設定指示操作,但請留意以下例外情況:
- 透過 kaggle.com 存取 RecurrentGemma (而非 Gemma)。
- 選取資源充足的 Colab 執行階段來執行 RecurrentGemma 模型。
- 產生並設定 Kaggle 使用者名稱和 API 金鑰。
完成 RecurrentGemma 設定後,請繼續前往下一節,設定 Colab 環境的環境變數。
設定環境變數
設定 KAGGLE_USERNAME
和 KAGGLE_KEY
的環境變數。系統顯示「授予存取權?」提示訊息時,訊息,同意提供密鑰存取權。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
安裝 recurrentgemma
程式庫
這個筆記本著重介紹免費的 Colab GPU,如要啟用硬體加速,請按一下「編輯」>筆記本設定 >依序選取「T4 GPU」>按一下「儲存」。
接下來,您需要從 github.com/google-deepmind/recurrentgemma
安裝 Google DeepMind recurrentgemma
程式庫。如果收到有關「pip 依附元件解析器」的錯誤,通常可以忽略。
pip install git+https://github.com/google-deepmind/recurrentgemma.git
載入並準備 RecurrentGemma 模型
- 使用
kagglehub.model_download
載入 RecurrentGemma 模型,該模型會使用三個引數:
handle
:Kaggle 的模型控制代碼path
:(選用字串) 本機路徑force_download
:(選用布林值) 強制重新下載模型
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- 檢查模型權重的位置和符記化工具,然後設定路徑變數。符記化工具目錄會存放您下載模型的主要目錄,而模型權重則位於子目錄。例如:
tokenizer.model
檔案會在/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
中)。- 模型查核點位於
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
)。
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
執行取樣/推論
- 使用
recurrentgemma.jax.load_parameters
方法載入 RecurrentGemma 模型查核點。sharding
引數設為"single_device"
會在單一裝置上載入所有模型參數。
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- 載入使用
sentencepiece.SentencePieceProcessor
建構的 RecurrentGemma 模型符記化工具:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- 如要自動從 RecurrentGemma 模型查核點載入正確的設定,請使用
recurrentgemma.GriffinConfig.from_flax_params_or_variables
。接著,使用recurrentgemma.jax.Griffin
將 Griffin 模型執行個體化。
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- 在 RecurrentGemma 模型查核點/權重和符記化工具之外,使用
recurrentgemma.jax.Sampler
建立sampler
:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- 在
prompt
中撰寫提示並執行推論。您可以調整total_generation_steps
(產生回應時執行的步驟數,此範例使用50
來保留主機記憶體)。
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
瞭解詳情
- 您可以進一步瞭解 GitHub 上的 Google DeepMind
recurrentgemma
程式庫,該程式庫包含您本教學課程使用的方法和模組文件字串,例如recurrentgemma.jax.load_parameters
、recurrentgemma.jax.Griffin
和recurrentgemma.jax.Sampler
。 - 下列程式庫都有專屬的說明文件網站:核心 JAX、Flax 以及 Orbax。
- 如需
sentencepiece
權杖化工具/解碼器說明文件,請前往 Google 的sentencepiece
GitHub 存放區。 - 如需
kagglehub
說明文件,請查看 Kagglekagglehub
GitHub 存放區中的README.md
。 - 瞭解如何搭配使用 Gemma 模型與 Google Cloud Vertex AI。
- 查看 RecurrentGemma: Moving Past Transformers 《The for Efficient Open Language Models》報告。
- 閱讀《Griffin:混合封閉式線性週期與 《Local Attention for Efficient Language Models》(高效語言模型的本機注意力) 報告:由 GoogleDeepMind 發布的模型,進一步瞭解 RecurrentGemma 使用的模型架構。