JAX と Flax を使用した CodeGemma による推論

Google DeepMind の Gemma モデル(Gemma チームら、2024 年)。CodeGemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。

Gemma の事前トレーニング済みモデルに続く CodeGemma モデルは、Gemma モデル ファミリーと同じアーキテクチャを使用して、主にコードの 500 ~ 1,000 億トークンでさらにトレーニングされます。その結果、CodeGemma モデルは、補完タスクと生成タスクの両方で最先端のコード パフォーマンスを達成しながら、大規模な理解と推論スキルを維持します。

CodeGemma には 3 つのバリアントがあります。

  • 70 億コードの事前トレーニング済みモデル
  • 7B 指示チューニング済みコードモデル
  • コード補完と自由形式の生成用に特別にトレーニングされた 2B モデル。

このガイドでは、コード補完タスクに Flax で CodeGemma モデルを使用する方法について説明します。

セットアップ

1. CodeGemma の Kaggle アクセスを設定する

このチュートリアルを完了するには、まず Gemma の設定の手順に沿って、次の操作を行う必要があります。

  • kaggle.com で CodeGemma にアクセスします。
  • 十分なリソース(T4 GPU のメモリが不足している場合は、代わりに TPU v2 を使用)を持つ Colab ランタイムを選択し、CodeGemma モデルを実行します。
  • Kaggle のユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら、次のセクションに進みます。ここでは、Colab 環境の環境変数を設定します。

2. 環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを許可します。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma ライブラリをインストールする

現在、無料の Colab ハードウェア アクセラレーションでは、このノートブックを実行するのにinsufficientColab 従量課金制または Colab Pro を使用している場合は、[編集] > [ノートブックの設定] > [A100 GPU] > [保存] の順にクリックして、ハードウェア アクセラレーションを有効にします。

次に、github.com/google-deepmind/gemma から Google DeepMind gemma ライブラリをインストールする必要があります。「pip の依存関係解決ツール」に関するエラーが表示された場合は、通常は無視できます。

pip install -q git+https://github.com/google-deepmind/gemma.git

4. ライブラリをインポートする

このノートブックでは、GemmaFlax を使用してニューラル ネットワーク レイヤを構築)と SentencePiece(トークン化用)を使用します。

import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm

CodeGemma モデルを読み込む

kagglehub.model_download を使用して CodeGemma モデルを読み込みます。この関数は 3 つの引数を受け取ります。

  • handle: Kaggle のモデルハンドル
  • path: (省略可の文字列)ローカルパス
  • force_download: (省略可、ブール値)モデルの再ダウンロードを強制します。
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

モデルの重みとトークンの場所を確認し、パス変数を設定します。トークン ディレクトリは、モデルをダウンロードしたメイン ディレクトリに配置されます。モデルの重み付けはサブディレクトリに配置されます。次に例を示します。

  • spm.model トークン化ファイルは /LOCAL/PATH/TO/codegemma/flax/2b-pt/3 にあります。
  • モデルのチェックポイントは /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt にあります。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

サンプリング/推論を実行する

gemma.params.load_and_format_params メソッドを使用して CodeGemma モデルのチェックポイントを読み込み、フォーマットします。

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor を使用して作成された CodeGemma トークン化ツールを読み込みます。

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma モデルのチェックポイントから正しい構成を自動的に読み込むには、gemma.deprecated.transformer.TransformerConfig を使用します。cache_size 引数は、CodeGemma Transformer キャッシュ内の時間ステップ数です。次に、gemma.deprecated.transformer.Transformerflax.linen.Module から継承)を使用して、CodeGemma モデルを model_2b としてインスタンス化します。

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Samplersampler を作成します。CodeGemma モデルのチェックポイントとトークン化ツールを使用します。

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

空白部分に入力する(fim)トークンを表す変数を作成します。また、プロンプトと生成された出力をフォーマットするヘルパー関数を作成します。

たとえば、次のコードを見てみましょう。

def function(string):
assert function('asdf') == 'fdsa'

アサーションに True が保持されるように function に値を入力します。この場合、接頭辞は次のようになります。

"def function(string):\n"

接尾辞は次のようになります。

"assert function('asdf') == 'fdsa'"

次に、このプロンプトを PREFIX-SUFFIX-MIDDLE という形式でフォーマットします(入力が必要な中間部分は常にプロンプトの末尾にあります)。

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

プロンプトを作成して推論を実行します。接頭辞 before テキストと接尾辞 after テキストを指定し、ヘルパー関数 format_completion prompt を使用して書式設定されたプロンプトを生成します。

total_generation_steps(レスポンスの生成時に実行されるステップ数)を調整できます。この例では、100 を使用してホストメモリを保持します。

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

その他の情報

  • Google DeepMind の gemma ライブラリ(GitHub)で詳細を確認できます。このライブラリには、このチュートリアルで使用したモジュールのドキュメント(gemma.paramsgemma.deprecated.transformergemma.sampler など)が含まれています。
  • コア JAXFlaxOrbax の各ライブラリには、独自のドキュメント サイトがあります。
  • sentencepiece トークン化ツール/トークン解除ツールのドキュメントについては、Google の sentencepiece GitHub リポジトリをご覧ください。
  • kagglehub のドキュメントについては、Kaggle の kagglehub GitHub リポジトリREADME.md をご覧ください。
  • Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
  • Google Cloud TPU(v3-8 以降)を使用している場合は、最新の jax[tpu] パッケージ(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html)に更新し、ランタイムを再起動して、jaxjaxlib のバージョンが一致していることを確認します(!pip list | grep jax)。これにより、jaxlibjax のバージョンの不一致が原因で発生する RuntimeError を防ぐことができます。JAX のインストール手順については、JAX のドキュメントをご覧ください。