Google DeepMind の Gemma モデル(Gemma チームら、2024 年)。CodeGemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。
Gemma の事前トレーニング済みモデルに続く CodeGemma モデルは、Gemma モデル ファミリーと同じアーキテクチャを使用して、主にコードの 500 ~ 1,000 億トークンでさらにトレーニングされます。その結果、CodeGemma モデルは、補完タスクと生成タスクの両方で最先端のコード パフォーマンスを達成しながら、大規模な理解と推論スキルを維持します。
CodeGemma には 3 つのバリアントがあります。
- 70 億コードの事前トレーニング済みモデル
- 7B 指示チューニング済みコードモデル
- コード補完と自由形式の生成用に特別にトレーニングされた 2B モデル。
このガイドでは、コード補完タスクに Flax で CodeGemma モデルを使用する方法について説明します。
セットアップ
1. CodeGemma の Kaggle アクセスを設定する
このチュートリアルを完了するには、まず Gemma の設定の手順に沿って、次の操作を行う必要があります。
- kaggle.com で CodeGemma にアクセスします。
- 十分なリソース(T4 GPU のメモリが不足している場合は、代わりに TPU v2 を使用)を持つ Colab ランタイムを選択し、CodeGemma モデルを実行します。
- Kaggle のユーザー名と API キーを生成して構成します。
Gemma の設定が完了したら、次のセクションに進みます。ここでは、Colab 環境の環境変数を設定します。
2. 環境変数を設定する
KAGGLE_USERNAME
と KAGGLE_KEY
の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを許可します。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. gemma
ライブラリをインストールする
現在、無料の Colab ハードウェア アクセラレーションでは、このノートブックを実行するのにinsufficient。Colab 従量課金制または Colab Pro を使用している場合は、[編集] > [ノートブックの設定] > [A100 GPU] > [保存] の順にクリックして、ハードウェア アクセラレーションを有効にします。
次に、github.com/google-deepmind/gemma
から Google DeepMind gemma
ライブラリをインストールする必要があります。「pip の依存関係解決ツール」に関するエラーが表示された場合は、通常は無視できます。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. ライブラリをインポートする
このノートブックでは、Gemma(Flax を使用してニューラル ネットワーク レイヤを構築)と SentencePiece(トークン化用)を使用します。
import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm
CodeGemma モデルを読み込む
kagglehub.model_download
を使用して CodeGemma モデルを読み込みます。この関数は 3 つの引数を受け取ります。
handle
: Kaggle のモデルハンドルpath
: (省略可の文字列)ローカルパスforce_download
: (省略可、ブール値)モデルの再ダウンロードを強制します。
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
モデルの重みとトークンの場所を確認し、パス変数を設定します。トークン ディレクトリは、モデルをダウンロードしたメイン ディレクトリに配置されます。モデルの重み付けはサブディレクトリに配置されます。次に例を示します。
spm.model
トークン化ファイルは/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
にあります。- モデルのチェックポイントは
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
にあります。
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
サンプリング/推論を実行する
gemma.params.load_and_format_params
メソッドを使用して CodeGemma モデルのチェックポイントを読み込み、フォーマットします。
params = params_lib.load_and_format_params(CKPT_PATH)
sentencepiece.SentencePieceProcessor
を使用して作成された CodeGemma トークン化ツールを読み込みます。
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
CodeGemma モデルのチェックポイントから正しい構成を自動的に読み込むには、gemma.deprecated.transformer.TransformerConfig
を使用します。cache_size
引数は、CodeGemma Transformer
キャッシュ内の時間ステップ数です。次に、gemma.deprecated.transformer.Transformer
(flax.linen.Module
から継承)を使用して、CodeGemma モデルを model_2b
としてインスタンス化します。
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
gemma.sampler.Sampler
で sampler
を作成します。CodeGemma モデルのチェックポイントとトークン化ツールを使用します。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
空白部分に入力する(fim)トークンを表す変数を作成します。また、プロンプトと生成された出力をフォーマットするヘルパー関数を作成します。
たとえば、次のコードを見てみましょう。
def function(string):
assert function('asdf') == 'fdsa'
アサーションに True
が保持されるように function
に値を入力します。この場合、接頭辞は次のようになります。
"def function(string):\n"
接尾辞は次のようになります。
"assert function('asdf') == 'fdsa'"
次に、このプロンプトを PREFIX-SUFFIX-MIDDLE という形式でフォーマットします(入力が必要な中間部分は常にプロンプトの末尾にあります)。
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
プロンプトを作成して推論を実行します。接頭辞 before
テキストと接尾辞 after
テキストを指定し、ヘルパー関数 format_completion prompt
を使用して書式設定されたプロンプトを生成します。
total_generation_steps
(レスポンスの生成時に実行されるステップ数)を調整できます。この例では、100
を使用してホストメモリを保持します。
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
その他の情報
- Google DeepMind の
gemma
ライブラリ(GitHub)で詳細を確認できます。このライブラリには、このチュートリアルで使用したモジュールのドキュメント(gemma.params
、gemma.deprecated.transformer
、gemma.sampler
など)が含まれています。 - コア JAX、Flax、Orbax の各ライブラリには、独自のドキュメント サイトがあります。
sentencepiece
トークン化ツール/トークン解除ツールのドキュメントについては、Google のsentencepiece
GitHub リポジトリをご覧ください。kagglehub
のドキュメントについては、Kaggle のkagglehub
GitHub リポジトリでREADME.md
をご覧ください。- Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
- Google Cloud TPU(v3-8 以降)を使用している場合は、最新の
jax[tpu]
パッケージ(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)に更新し、ランタイムを再起動して、jax
とjaxlib
のバージョンが一致していることを確認します(!pip list | grep jax
)。これにより、jaxlib
とjax
のバージョンの不一致が原因で発生するRuntimeError
を防ぐことができます。JAX のインストール手順については、JAX のドキュメントをご覧ください。