ai.google.dev で表示 | Google Colab で実行 | Vertex AI で開く | GitHub のソースを表示 |
このチュートリアルでは、JAX(高性能数値計算ライブラリ)、Flax(JAX ベースのニューラル ネットワーク ライブラリ)、Orbax(JAX ベースのトレーニング用ライブラリ)、checkpoint1 などのユーティリティ用の JAX ベースのライブラリ(checkpoint1)で記述された Google DeepMind の recurrentgemma
ライブラリを使用して、RecurrentGemma 2B Instruct モデルで基本的なサンプリング / 推論を行う方法を説明します。SentencePieceこのノートブックでは Flax は直接使用されていませんが、Gemma と RecurrentGemma(Griffin モデル)の作成には Flax が使用されました。
このノートブックは、T4 GPU を使用して Google Colab で実行できます([編集] > [ノートブック設定] > [ハードウェア アクセラレータ] で [T4 GPU] を選択します)。
セットアップ
以降のセクションでは、RecurrentGemma モデルを使用するノートブックを準備する手順(モデルへのアクセス、API キーの取得、ノートブック ランタイムの構成など)について説明します。
Gemma 用に Kaggle のアクセスを設定する
このチュートリアルを完了するには、まず Gemma の設定と同様の設定手順を行う必要があります。ただし、いくつかの例外があります。
- kaggle.com で(Gemma ではなく)RecurrentGemma にアクセスしてください。
- RecurrentGemma モデルの実行に十分なリソースを備えた Colab ランタイムを選択します。
- Kaggle のユーザー名と API キーを生成して構成します。
RecurrentGemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。
環境変数を設定する
KAGGLE_USERNAME
と KAGGLE_KEY
の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを提供することに合意します。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
recurrentgemma
ライブラリをインストールする
このノートブックでは、無料の Colab GPU の使用に焦点を当てています。ハードウェア アクセラレーションを有効にするには、[編集] >ノートブックの設定 >[T4 GPU] を選択 >保存。
次に、github.com/google-deepmind/recurrentgemma
から Google DeepMind recurrentgemma
ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視できます。
pip install git+https://github.com/google-deepmind/recurrentgemma.git
RecurrentGemma モデルを読み込んで準備する
kagglehub.model_download
を使用して RecurrentGemma モデルを読み込みます。これは 3 つの引数を取ります。
handle
: Kaggle のモデルハンドルpath
: (省略可)ローカルパスforce_download
: (省略可のブール値)モデルの再ダウンロードを強制的に行います
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- モデルの重みとトークナイザの場所を確認して、パス変数を設定します。トークナイザー ディレクトリはモデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。例:
tokenizer.model
ファイルは/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
にあります)。- モデルのチェックポイントは
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
にあります)。
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
サンプリング/推論
recurrentgemma.jax.load_parameters
メソッドを使用して、RecurrentGemma モデルのチェックポイントを読み込みます。sharding
引数を"single_device"
に設定すると、1 つのデバイスにすべてのモデル パラメータが読み込まれます。
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
sentencepiece.SentencePieceProcessor
を使用して作成された RecurrentGemma モデル トークナイザを読み込みます。
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- RecurrentGemma モデルのチェックポイントから正しい構成を自動的に読み込むには、
recurrentgemma.GriffinConfig.from_flax_params_or_variables
を使用します。次に、recurrentgemma.jax.Griffin
を使用して Griffin モデルをインスタンス化します。
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- RecurrentGemma モデルのチェックポイント/重みとトークナイザの上に
recurrentgemma.jax.Sampler
を使用してsampler
を作成します。
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
prompt
でプロンプトを記述し、推論を実行します。total_generation_steps
(レスポンスの生成時に実行するステップ数。この例では50
を使用してホストメモリを保持)を微調整できます。
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
その他の情報
- Google DeepMind の GitHub の
recurrentgemma
ライブラリについて詳しくは、GitHub のrecurrentgemma
ライブラリをご覧ください。このライブラリには、このチュートリアルで使用したメソッドやモジュールのドキュメント文字列(recurrentgemma.jax.load_parameters
、recurrentgemma.jax.Griffin
、recurrentgemma.jax.Sampler
など)が含まれています。 - core JAX、Flax、Orbax の各ライブラリには独自のドキュメント サイトがあります。
sentencepiece
トークナイザーとデトークナイザーのドキュメントについては、Google のsentencepiece
GitHub リポジトリをご覧ください。kagglehub
のドキュメントについては、Kaggle のkagglehub
GitHub リポジトリでREADME.md
をご覧ください。- Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
- 「RecurrentGemma: Moving Past Transformers」 「Efficient Open Language Models」レポートを作成しました。
- 「Griffin: Mixing Gated Linear Recurrences with GoogleDeepMind による Local Attention for Efficient Language Models という論文で、RecurrentGemma で使用されるモデル アーキテクチャについて詳しく説明しています。