JAX と Flax を使用した RecurrentGemma による推論

ai.google.dev で表示 Google Colab で実行 Vertex AI で開く GitHub のソースを表示

このチュートリアルでは、JAX(高性能数値計算ライブラリ)、Flax(JAX ベースのニューラル ネットワーク ライブラリ)、Orbax(JAX ベースのトレーニング用ライブラリ)、checkpoint1 などのユーティリティ用の JAX ベースのライブラリ(checkpoint1)で記述された Google DeepMind の recurrentgemma ライブラリを使用して、RecurrentGemma 2B Instruct モデルで基本的なサンプリング / 推論を行う方法を説明します。SentencePieceこのノートブックでは Flax は直接使用されていませんが、Gemma と RecurrentGemma(Griffin モデル)の作成には Flax が使用されました。

このノートブックは、T4 GPU を使用して Google Colab で実行できます([編集] > [ノートブック設定] > [ハードウェア アクセラレータ] で [T4 GPU] を選択します)。

セットアップ

以降のセクションでは、RecurrentGemma モデルを使用するノートブックを準備する手順(モデルへのアクセス、API キーの取得、ノートブック ランタイムの構成など)について説明します。

Gemma 用に Kaggle のアクセスを設定する

このチュートリアルを完了するには、まず Gemma の設定同様の設定手順を行う必要があります。ただし、いくつかの例外があります。

  • kaggle.com で(Gemma ではなく)RecurrentGemma にアクセスしてください。
  • RecurrentGemma モデルの実行に十分なリソースを備えた Colab ランタイムを選択します。
  • Kaggle のユーザー名と API キーを生成して構成します。

RecurrentGemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。

環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを提供することに合意します。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma ライブラリをインストールする

このノートブックでは、無料の Colab GPU の使用に焦点を当てています。ハードウェア アクセラレーションを有効にするには、[編集] >ノートブックの設定 >[T4 GPU] を選択 >保存

次に、github.com/google-deepmind/recurrentgemma から Google DeepMind recurrentgemma ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視できます。

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma モデルを読み込んで準備する

  1. kagglehub.model_download を使用して RecurrentGemma モデルを読み込みます。これは 3 つの引数を取ります。
  • handle: Kaggle のモデルハンドル
  • path: (省略可)ローカルパス
  • force_download: (省略可のブール値)モデルの再ダウンロードを強制的に行います
で確認できます。
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. モデルの重みとトークナイザの場所を確認して、パス変数を設定します。トークナイザー ディレクトリはモデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。例:
  • tokenizer.model ファイルは /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 にあります)。
  • モデルのチェックポイントは /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it にあります)。
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

サンプリング/推論

  1. recurrentgemma.jax.load_parameters メソッドを使用して、RecurrentGemma モデルのチェックポイントを読み込みます。sharding 引数を "single_device" に設定すると、1 つのデバイスにすべてのモデル パラメータが読み込まれます。
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. sentencepiece.SentencePieceProcessor を使用して作成された RecurrentGemma モデル トークナイザを読み込みます。
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. RecurrentGemma モデルのチェックポイントから正しい構成を自動的に読み込むには、recurrentgemma.GriffinConfig.from_flax_params_or_variables を使用します。次に、recurrentgemma.jax.Griffin を使用して Griffin モデルをインスタンス化します。
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. RecurrentGemma モデルのチェックポイント/重みとトークナイザの上に recurrentgemma.jax.Sampler を使用して sampler を作成します。
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt でプロンプトを記述し、推論を実行します。total_generation_steps(レスポンスの生成時に実行するステップ数。この例では 50 を使用してホストメモリを保持)を微調整できます。
で確認できます。
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

その他の情報