JAX と Flax を使用した Gemma による推論

ai.google.dev で表示 Google Colab で実行 Vertex AI で開く GitHub でソースを表示

概要

Gemma は、Google DeepMind Gemini の研究とテクノロジーをベースにした、軽量で最先端のオープン 大規模言語モデルのファミリーです。このチュートリアルでは、JAX(高性能数値計算ライブラリ)、Flax(JAX ベースのニューラル ネットワーク ライブラリ)、Orbax(チェックポインティング / トークン化ツールなどのトレーニング ユーティリティ用の JAX ベースのライブラリ{/10)で作成された Google DeepMind の gemma ライブラリを使用して、Gemma 2B Instruct モデルで基本的なサンプリング / 推論を行う方法を示します。SentencePieceこのノートブックでは Flax は直接使用されていませんが、Gemma の作成には Flax が使用されました。

このノートブックは、無料の T4 GPU を使用して Google Colab で実行できます([編集] > [ノートブックの設定] > [ハードウェア アクセラレータ] で [T4 GPU] を選択します)。

設定

1. Gemma 用に Kaggle のアクセス権を設定する

このチュートリアルを完了するには、まず Gemma のセットアップに記載されている手順に沿って操作する必要があります。ここでは、以下の方法について説明しています。

  • kaggle.com で Gemma にアクセスできます。
  • Gemma モデルの実行に十分なリソースがある Colab ランタイムを選択します。
  • Kaggle ユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら次のセクションに進み、Colab 環境の環境変数を設定します。

2. 環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。「アクセスを許可しますか?」というメッセージが表示されたら、シークレット アクセスの提供に同意します。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma ライブラリをインストールする

このノートブックでは、無料の Colab GPU の使用に焦点を当てています。ハードウェア アクセラレーションを有効にするには、[編集] > [ノートブック設定] > [T4 GPU] > [保存] の順にクリックします。

次に、github.com/google-deepmind/gemma から Google DeepMind gemma ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視してかまいません。

pip install -q git+https://github.com/google-deepmind/gemma.git

Gemma モデルを読み込んで準備する

  1. kagglehub.model_download を使用して Gemma モデルを読み込みます。これは 3 つの引数を取ります。
  • handle: Kaggle のモデルハンドル
  • path: (省略可能な文字列)ローカルパス
  • force_download: (ブール値)(省略可)モデルを強制的に再ダウンロードします。
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. モデルの重みとトークナイザの場所を確認してから、パス変数を設定します。トークナイザ ディレクトリは、モデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。次に例を示します。
  • tokenizer.model ファイルは /LOCAL/PATH/TO/gemma/flax/2b-it/2 にあります)。
  • モデルのチェックポイントは /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it にあります)。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

サンプリング/推論を実行する

  1. gemma.params.load_and_format_params メソッドを使用して、Gemma モデル チェックポイントを読み込み、フォーマットします。
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. sentencepiece.SentencePieceProcessor を使用して作成された Gemma トークナイザを読み込みます。
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Gemma モデル チェックポイントから正しい構成を自動的に読み込むには、gemma.transformer.TransformerConfig を使用します。cache_size 引数は、Gemma Transformer キャッシュ内のタイムステップの数です。その後、gemma.transformer.Transformerflax.linen.Module から継承)を使用して、Gemma モデルを transformer としてインスタンス化します。
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Gemma モデルのチェックポイント/重みとトークナイザの上に、gemma.sampler.Sampler を使用して sampler を作成します。
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. input_batch でプロンプトを作成し、推論を実行します。total_generation_steps(レスポンスの生成時に実行されるステップ数)を微調整できます。この例では、100 を使用してホストのメモリを節約しています。
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (省略可)ノートブックの作成が完了し、別のプロンプトを試す場合は、このセルを実行してメモリを解放します。その後、手順 3 で sampler を再度インスタンス化し、手順 4 でプロンプトをカスタマイズして実行できます。
del sampler

詳細