JAX と Flax を使用した CodeGemma による推論

ai.google.dev で表示 Google Colab で実行 GitHub のソースを表示

CodeGemma をご紹介します。CodeGemma は、Google DeepMind の Gemma モデル(Gemma Team 他、2024)。 CodeGemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。

Gemma の事前トレーニング済みモデルに続き、CodeGemma モデルは アーキテクチャは Gemma モデルファミリーと同じですその結果、CodeGemma モデルは、どちらの補完においても最先端のコード パフォーマンスを達成しています。 生成タスクをサポートしながら、強力な 大規模に理解し、推論する能力が必要です。

CodeGemma には次の 3 つのバリアントがあります。

  • 7B コードの事前トレーニング済みモデル
  • 7B 命令チューニング済みのコードモデル
  • コード インフィルとオープンエンド生成のために特別にトレーニングされた 2B モデル。

このガイドでは、コード補完タスクで CodeGemma モデルを Flax で使用する方法について説明します。

セットアップ

1. CodeGemma の Kaggle アクセスを設定する

このチュートリアルを完了するには、まず、Gemma の設定に記載されている設定手順を実施する必要があります。この手順では、以下を行う方法について説明します。

  • kaggle.com で CodeGemma にアクセスします。
  • 十分なリソースがある Colab ランタイムを選択して(T4 GPU のメモリが不十分です。代わりに TPU v2 を使用してください)、CodeGemma モデルを実行します。
  • Kaggle のユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。

2. 環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを提供することに合意します。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma ライブラリをインストールする

現在、このノートブックを実行するには、無料の Colab ハードウェア アクセラレーションでは不十分です。Colab 従量課金制または Colab Pro を使用している場合は、[編集] をクリックします。ノートブックの設定 >[A100 GPU] を選択 >ハードウェア アクセラレーションを有効にするには、[保存] をクリックします。

次に、github.com/google-deepmind/gemma から Google DeepMind gemma ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視できます。

pip install -q git+https://github.com/google-deepmind/gemma.git

4. ライブラリをインポートする

このノートブックでは、GemmaFlax を使用してニューラル ネットワーク レイヤを構築)と SentencePiece(トークン化)を使用します。

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

CodeGemma モデルを読み込む

kagglehub.model_download を使用して CodeGemma モデルを読み込みます。これは、次の 3 つの引数を取ります。

  • handle: Kaggle のモデルハンドル
  • path: (省略可)ローカルパス
  • force_download: (省略可のブール値)モデルの再ダウンロードを強制的に行います
で確認できます。
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

モデルの重みとトークナイザの場所を確認して、パス変数を設定します。トークナイザー ディレクトリはモデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。例:

  • spm.model トークナイザ ファイルは /LOCAL/PATH/TO/codegemma/flax/2b-pt/3 にあります。
  • モデルのチェックポイントは /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt にあります
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

サンプリング/推論

gemma.params.load_and_format_params メソッドを使用して、CodeGemma モデルのチェックポイントを読み込んでフォーマットします。

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor を使用して作成された CodeGemma トークナイザを読み込みます。

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma モデルのチェックポイントから正しい構成を自動的に読み込むには、gemma.transformer.TransformerConfig を使用します。cache_size 引数は、CodeGemma の Transformer キャッシュ内のタイムステップの数です。その後、gemma.transformer.Transformerflax.linen.Module から継承)を使用して、CodeGemma モデルを model_2b としてインスタンス化します。

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Samplersampler を作成します。CodeGemma モデルのチェックポイントとトークナイザを使用します。

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

中間を埋める(fim)トークンを表す変数をいくつか作成し、プロンプトと生成された出力をフォーマットするヘルパー関数を作成します。

たとえば、次のコードを見てみましょう。

def function(string):
assert function('asdf') == 'fdsa'

アサーションが True を保持するように function を入力します。この場合、接頭辞は次のようになります。

"def function(string):\n"

接尾辞は次のようになります。

"assert function('asdf') == 'fdsa'"

これを PREFIX-SUFFIX-MIDDLE というプロンプトの形式にします(入力する必要がある中央のセクションは常にプロンプトの最後にあります)。

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

プロンプトを作成して推論を実行する。接頭辞 before テキストと接尾辞 after テキストを指定し、ヘルパー関数 format_completion prompt を使用してフォーマット済みプロンプトを生成します。

total_generation_steps(レスポンスの生成時に実行するステップ数。この例では 100 を使用してホストメモリを保持)を微調整できます。

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

その他の情報

  • Google DeepMind の GitHub の gemma ライブラリの詳細を確認できます。これには、このチュートリアルで使用したモジュールの docstring が含まれています(gemma.paramsgemma.transformergemma.sampler
  • core JAXFlaxOrbax の各ライブラリには独自のドキュメント サイトがあります。
  • sentencepiece トークナイザーとデトークナイザーのドキュメントについては、Google の sentencepiece GitHub リポジトリをご覧ください。
  • kagglehub のドキュメントについては、Kaggle の kagglehub GitHub リポジトリREADME.md をご覧ください。
  • Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
  • Google Cloud TPU(v3-8 以降)を使用している場合は、最新の jax[tpu] パッケージ(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html)に更新し、ランタイムを再起動して、jaxjaxlib のバージョンが一致している(!pip list | grep jax)ことを確認してください。これにより、jaxlibjax のバージョンの不一致が原因で発生する RuntimeError を防ぐことができます。JAX のインストール手順の詳細については、JAX のドキュメントをご覧ください。