使用 JAX 和 Flax 透過 RecurrentGemma 推論

前往 ai.google.dev 查看 在 Google Colab 中執行 在 Vertex AI 中開啟 前往 GitHub 查看原始碼

本教學課程示範如何使用 RecurrentGemma 2B Instruct 模型執行基本的取樣/推論:使用以 JAX (高效能數值運算程式庫)、Flax (例如 JAX 型類神經網路程式庫) 和 Orbax雖然這個筆記本中並未直接使用 Flax,但 Flax 是用來建立 Gemma 和 RecurrentGemma (Griffin 模型) 使用。

這個筆記本可以在 Google Colab 和 T4 GPU 上執行 (依序前往「編輯」>「筆記本設定」>「硬體加速器」下方選取「T4 GPU」)。

設定

以下各節說明如何準備筆記本以使用 RecurrentGemma 模型,包括模型存取、取得 API 金鑰,以及設定筆記本執行階段

設定 Gemma 的 Kaggle 存取權

如要完成本教學課程,請先按照類似 Gemma 設定的設定指示操作,但請留意以下例外情況:

  • 透過 kaggle.com 存取 RecurrentGemma (而非 Gemma)。
  • 選取資源充足的 Colab 執行階段來執行 RecurrentGemma 模型。
  • 產生並設定 Kaggle 使用者名稱和 API 金鑰。

完成 RecurrentGemma 設定後,請繼續前往下一節,設定 Colab 環境的環境變數。

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。系統顯示「授予存取權?」提示訊息時,訊息,同意提供密鑰存取權。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安裝 recurrentgemma 程式庫

這個筆記本著重介紹免費的 Colab GPU,如要啟用硬體加速,請按一下「編輯」>筆記本設定 >依序選取「T4 GPU」>按一下「儲存」

接下來,您需要從 github.com/google-deepmind/recurrentgemma 安裝 Google DeepMind recurrentgemma 程式庫。如果收到有關「pip 依附元件解析器」的錯誤,通常可以忽略。

pip install git+https://github.com/google-deepmind/recurrentgemma.git

載入並準備 RecurrentGemma 模型

  1. 使用 kagglehub.model_download 載入 RecurrentGemma 模型,該模型會使用三個引數:
  • handle:Kaggle 的模型控制代碼
  • path:(選用字串) 本機路徑
  • force_download:(選用布林值) 強制重新下載模型
,瞭解如何調查及移除這項存取權。
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. 檢查模型權重的位置和符記化工具,然後設定路徑變數。符記化工具目錄會存放您下載模型的主要目錄,而模型權重則位於子目錄。例如:
  • tokenizer.model 檔案會在 /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 中)。
  • 模型查核點位於 /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it)。
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

執行取樣/推論

  1. 使用 recurrentgemma.jax.load_parameters 方法載入 RecurrentGemma 模型查核點。sharding 引數設為 "single_device" 會在單一裝置上載入所有模型參數。
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. 載入使用 sentencepiece.SentencePieceProcessor 建構的 RecurrentGemma 模型符記化工具:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. 如要自動從 RecurrentGemma 模型查核點載入正確的設定,請使用 recurrentgemma.GriffinConfig.from_flax_params_or_variables。接著,使用 recurrentgemma.jax.GriffinGriffin 模型執行個體化。
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. 在 RecurrentGemma 模型查核點/權重和符記化工具之外,使用 recurrentgemma.jax.Sampler 建立 sampler
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt 中撰寫提示並執行推論。您可以調整 total_generation_steps (產生回應時執行的步驟數,此範例使用 50 來保留主機記憶體)。
,瞭解如何調查及移除這項存取權。
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

瞭解詳情