ai.google.dev で表示 | Google Colab で実行 | GitHub のソースを表示 |
CodeGemma をご紹介します。CodeGemma は、Google DeepMind の Gemma モデル(Gemma Team 他、2024)。 CodeGemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。
Gemma の事前トレーニング済みモデルに続き、CodeGemma モデルは アーキテクチャは Gemma モデルファミリーと同じですその結果、CodeGemma モデルは、どちらの補完においても最先端のコード パフォーマンスを達成しています。 生成タスクをサポートしながら、強力な 大規模に理解し、推論する能力が必要です。
CodeGemma には次の 3 つのバリアントがあります。
- 7B コードの事前トレーニング済みモデル
- 7B 命令チューニング済みのコードモデル
- コード インフィルとオープンエンド生成のために特別にトレーニングされた 2B モデル。
このガイドでは、コード補完タスクで CodeGemma モデルを Flax で使用する方法について説明します。
セットアップ
1. CodeGemma の Kaggle アクセスを設定する
このチュートリアルを完了するには、まず、Gemma の設定に記載されている設定手順を実施する必要があります。この手順では、以下を行う方法について説明します。
- kaggle.com で CodeGemma にアクセスします。
- 十分なリソースがある Colab ランタイムを選択して(T4 GPU のメモリが不十分です。代わりに TPU v2 を使用してください)、CodeGemma モデルを実行します。
- Kaggle のユーザー名と API キーを生成して構成します。
Gemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。
2. 環境変数を設定する
KAGGLE_USERNAME
と KAGGLE_KEY
の環境変数を設定します。[アクセスを許可しますか?] というメッセージが表示されたら、シークレットへのアクセスを提供することに合意します。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. gemma
ライブラリをインストールする
現在、このノートブックを実行するには、無料の Colab ハードウェア アクセラレーションでは不十分です。Colab 従量課金制または Colab Pro を使用している場合は、[編集] をクリックします。ノートブックの設定 >[A100 GPU] を選択 >ハードウェア アクセラレーションを有効にするには、[保存] をクリックします。
次に、github.com/google-deepmind/gemma
から Google DeepMind gemma
ライブラリをインストールする必要があります。「pip の依存関係リゾルバ」に関するエラーが発生した場合、通常は無視できます。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. ライブラリをインポートする
このノートブックでは、Gemma(Flax を使用してニューラル ネットワーク レイヤを構築)と SentencePiece(トークン化)を使用します。
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
CodeGemma モデルを読み込む
kagglehub.model_download
を使用して CodeGemma モデルを読み込みます。これは、次の 3 つの引数を取ります。
handle
: Kaggle のモデルハンドルpath
: (省略可)ローカルパスforce_download
: (省略可のブール値)モデルの再ダウンロードを強制的に行います
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
モデルの重みとトークナイザの場所を確認して、パス変数を設定します。トークナイザー ディレクトリはモデルをダウンロードしたメイン ディレクトリにありますが、モデルの重みはサブディレクトリにあります。例:
spm.model
トークナイザ ファイルは/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
にあります。- モデルのチェックポイントは
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
にあります
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
サンプリング/推論
gemma.params.load_and_format_params
メソッドを使用して、CodeGemma モデルのチェックポイントを読み込んでフォーマットします。
params = params_lib.load_and_format_params(CKPT_PATH)
sentencepiece.SentencePieceProcessor
を使用して作成された CodeGemma トークナイザを読み込みます。
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
CodeGemma モデルのチェックポイントから正しい構成を自動的に読み込むには、gemma.transformer.TransformerConfig
を使用します。cache_size
引数は、CodeGemma の Transformer
キャッシュ内のタイムステップの数です。その後、gemma.transformer.Transformer
(flax.linen.Module
から継承)を使用して、CodeGemma モデルを model_2b
としてインスタンス化します。
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
gemma.sampler.Sampler
で sampler
を作成します。CodeGemma モデルのチェックポイントとトークナイザを使用します。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
中間を埋める(fim)トークンを表す変数をいくつか作成し、プロンプトと生成された出力をフォーマットするヘルパー関数を作成します。
たとえば、次のコードを見てみましょう。
def function(string):
assert function('asdf') == 'fdsa'
アサーションが True
を保持するように function
を入力します。この場合、接頭辞は次のようになります。
"def function(string):\n"
接尾辞は次のようになります。
"assert function('asdf') == 'fdsa'"
これを PREFIX-SUFFIX-MIDDLE というプロンプトの形式にします(入力する必要がある中央のセクションは常にプロンプトの最後にあります)。
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
プロンプトを作成して推論を実行する。接頭辞 before
テキストと接尾辞 after
テキストを指定し、ヘルパー関数 format_completion prompt
を使用してフォーマット済みプロンプトを生成します。
total_generation_steps
(レスポンスの生成時に実行するステップ数。この例では 100
を使用してホストメモリを保持)を微調整できます。
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
その他の情報
- Google DeepMind の GitHub の
gemma
ライブラリの詳細を確認できます。これには、このチュートリアルで使用したモジュールの docstring が含まれています(gemma.params
、gemma.transformer
、gemma.sampler
。 - core JAX、Flax、Orbax の各ライブラリには独自のドキュメント サイトがあります。
sentencepiece
トークナイザーとデトークナイザーのドキュメントについては、Google のsentencepiece
GitHub リポジトリをご覧ください。kagglehub
のドキュメントについては、Kaggle のkagglehub
GitHub リポジトリでREADME.md
をご覧ください。- Google Cloud Vertex AI で Gemma モデルを使用する方法を学習する。
- Google Cloud TPU(v3-8 以降)を使用している場合は、最新の
jax[tpu]
パッケージ(!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)に更新し、ランタイムを再起動して、jax
とjaxlib
のバージョンが一致している(!pip list | grep jax
)ことを確認してください。これにより、jaxlib
とjax
のバージョンの不一致が原因で発生するRuntimeError
を防ぐことができます。JAX のインストール手順の詳細については、JAX のドキュメントをご覧ください。