使用 JAX 和 Flax 透過 Gemma 進行推論

前往 ai.google.dev 查看 在 Google Colab 中執行 在 Vertex AI 中開啟 在 GitHub 上查看原始碼

總覽

Gemma 是一系列最先進的輕量開放大型語言模型系列,以 Google DeepMind Gemini 的研究和技術為基礎。本教學課程說明如何使用 Gemma 2B Instruct 模型執行基本的取樣/推論作業,這個模型使用 JAX (高效能數值運算程式庫)、Flax (以 JAX 為基礎的類神經網路程式庫)、Orbax (以 JAX 為基礎的 JAX 權杖,以及檢查點等,用於訓練公用程式的程式庫)。gemmaSentencePiece雖然這個筆記本並未直接使用 Flax,但它是用來建立 Gemma。

這個筆記本可以在搭配免費的 T4 GPU 的 Google Colab 上執行 (依序前往「編輯」>「筆記本設定」,然後在「硬體加速器」下方選取「T4 GPU」)。

設定

1. 設定 Gemma 的 Kaggle 存取權

如要完成本教學課程,您必須先按照 Gemma 設定中的操作說明進行設定,瞭解如何執行下列操作:

  • 前往 kaggle.com 存取 Gemma。
  • 請選取具備足夠資源來執行 Gemma 模型的 Colab 執行階段。
  • 產生並設定 Kaggle 使用者名稱與 API 金鑰。

完成 Gemma 設定後,請前往下一節,為 Colab 環境設定環境變數。

2. 設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。出現「授予存取權?」訊息時,請同意提供密鑰存取權。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. 安裝 gemma 程式庫

這個筆記本著重於使用免費的 Colab GPU,如要啟用硬體加速功能,請依序點選「編輯」 >「筆記本設定」 >「T4 GPU」 >「儲存」

接下來,你必須從 github.com/google-deepmind/gemma 安裝 Google DeepMind gemma 程式庫。如果系統顯示「pip 依附元件解析器」的錯誤訊息,通常可以予以忽略。

pip install -q git+https://github.com/google-deepmind/gemma.git

載入並準備 Gemma 模型

  1. 使用 kagglehub.model_download 載入 Gemma 模型,該模型會採用三個引數:
  • handle:Kaggle 的模型控制代碼
  • path:(選用字串) 本機路徑
  • force_download:(選用布林值) 強制重新下載模型
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. 請檢查模型權重和權杖化工具的位置,然後設定路徑變數。符記化工具目錄位於您下載模型的主目錄中,模型權重則位於子目錄中。例如:
  • tokenizer.model 檔案會位於 /LOCAL/PATH/TO/gemma/flax/2b-it/2)。
  • 模型查核點會顯示在 /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it 中)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

執行取樣/推論

  1. 使用 gemma.params.load_and_format_params 方法載入 Gemma 模型查核點並設定格式:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. 載入使用 sentencepiece.SentencePieceProcessor 建構的 Gemma 權杖化工具:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. 如要從 Gemma 模型查核點自動載入正確的設定,請使用 gemma.transformer.TransformerConfigcache_size 引數是 Gemma Transformer 快取中的步數。之後,請使用 gemma.transformer.Transformer (繼承自 flax.linen.Module) 將 Gemma 模型例項化為 transformer
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. 在 Gemma 模型查核點/權重和權杖化工具上方,使用 gemma.sampler.Sampler 建立 sampler
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. input_batch 中撰寫提示並執行推論。您可以調整 total_generation_steps (產生回應時執行的步驟數,此範例使用 100 來保留主機記憶體)。
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (選用) 完成筆記本後,如要嘗試其他提示,請執行這個儲存格來釋出記憶體。之後,您可以在步驟 3 再次將 sampler 例項化,然後自訂並執行步驟 4 的提示。
del sampler

瞭解詳情