我們推出 CodeGemma,這是一組開放原始碼模型,以 Google DeepMind 的 Gemma 模型為基礎 (Gemma Team et al., 2024 年)。CodeGemma 是一系列先進的輕量級開放式模型,採用與建立 Gemini 模型時相同的研究成果和技術。
接續 Gemma 預先訓練模型,CodeGemma 模型會進一步針對 500 至 1000 億個主要程式碼符記進行訓練,並使用與 Gemma 模型系列相同的架構。因此,CodeGemma 模型可在完成和產生任務中達到頂尖程式碼效能,同時維持強大的理解和推理技能。
CodeGemma 有 3 種變化版本:
- 7B 程式碼預先訓練模型
- 7B 指令調整的程式碼模型
- 專門針對程式碼補全和開放式生成作業訓練的 2B 模型。
本指南將逐步說明如何使用 CodeGemma 模型搭配 Flax,執行程式碼補全作業。
設定
1. 為 CodeGemma 設定 Kaggle 存取權
如要完成本教學課程,您必須先按照 Gemma 設定中的說明操作,瞭解如何執行下列操作:
- 前往 kaggle.com 取得 CodeGemma 存取權。
- 選取有足夠資源的 Colab 執行階段 (T4 GPU 記憶體不足,請改用 TPU v2) 來執行 CodeGemma 模型。
- 產生及設定 Kaggle 使用者名稱和 API 金鑰。
完成 Gemma 設定後,請繼續閱讀下一節,瞭解如何設定 Colab 環境的環境變數。
2. 設定環境變數
設定 KAGGLE_USERNAME
和 KAGGLE_KEY
的環境變數。當系統提示「授予存取權嗎?」時,請同意提供密鑰存取權。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. 安裝 gemma
程式庫
免費的 Colab 硬體加速功能目前insufficient,無法執行這個筆記本。如果您使用 Colab Pay As You Go 或 Colab Pro,請依序點選「編輯」 >「筆記本設定」 > 選取「A100 GPU」 >「儲存」,啟用硬體加速功能。
接著,您需要從 github.com/google-deepmind/gemma
安裝 Google DeepMind gemma
程式庫。如果您收到有關「pip 的依附元件解析工具」的錯誤訊息,通常可以忽略。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. 匯入程式庫
本筆記本使用 Gemma (使用 Flax 建構神經網路層) 和 SentencePiece (用於標記)。
import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm
載入 CodeGemma 模型
使用 kagglehub.model_download
載入 CodeGemma 模型,該函式需要三個引數:
handle
:Kaggle 的模型句柄path
:(選用字串) 本機路徑force_download
:(選用布林值) 強制重新下載模型
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
檢查模型權重和分析器的位置,然後設定路徑變數。分詞器目錄會位於您下載模型的主目錄中,而模型權重會位於子目錄中。例如:
spm.model
分詞器檔案會位於/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- 模型查核點會位於
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
執行取樣/推論
使用 gemma.params.load_and_format_params
方法載入及格式化 CodeGemma 模型查核點:
params = params_lib.load_and_format_params(CKPT_PATH)
載入 CodeGemma 分詞器,該分詞器是使用 sentencepiece.SentencePieceProcessor
建構而成:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
如要自動從 CodeGemma 模型檢查點載入正確的設定,請使用 gemma.deprecated.transformer.TransformerConfig
。cache_size
引數是 CodeGemma Transformer
快取中的時間步數。接著,請使用 gemma.deprecated.transformer.Transformer
(繼承自 flax.linen.Module
) 將 CodeGemma 模型例項化為 model_2b
。
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
使用 gemma.sampler.Sampler
建立 sampler
。它會使用 CodeGemma 模型查核點和剖析器。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
建立一些變數來代表填入中間 (FIM) 符記,並建立一些輔助函式來格式化提示和產生的輸出內容。
舉例來說,請看下列程式碼:
def function(string):
assert function('asdf') == 'fdsa'
我們想填入 function
,讓斷言保留 True
。在這種情況下,前置字串會是:
"def function(string):\n"
後置字串則為:
"assert function('asdf') == 'fdsa'"
然後,我們將這段文字格式化為提示,格式為「前置字串-後置字串-中間」(需要填入內容的中間部分一律會出現在提示結尾):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
建立提示並執行推論。指定前置字串 before
和後置字串 after
,並使用輔助函式 format_completion prompt
產生格式化的提示。
您可以調整 total_generation_steps
(產生回應時執行的步驟數量,這個範例使用 100
來保留主機記憶體)。
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
瞭解詳情
- 您可以在 GitHub 上進一步瞭解 Google DeepMind
gemma
程式庫,其中包含您在本教學課程中使用的模組的 docstring,例如gemma.params
、gemma.deprecated.transformer
和gemma.sampler
。 - 下列程式庫有各自的說明文件網站:核心 JAX、Flax 和 Orbax。
- 如需
sentencepiece
分詞器/解分詞器說明文件,請查看 Google 的sentencepiece
GitHub 存放區。 - 如需
kagglehub
說明文件,請前往 Kaggle 的kagglehub
GitHub 存放區查看README.md
。 - 瞭解如何搭配 Google Cloud Vertex AI 使用 Gemma 模型。
- 如果您使用 Google Cloud TPU (v3-8 以上版本),請務必一併更新至最新的
jax[tpu]
套件 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)、重新啟動執行階段,並檢查jax
和jaxlib
版本是否相符 (!pip list | grep jax
)。這樣做可避免因jaxlib
和jax
版本不相符而導致RuntimeError
發生。如需更多 JAX 安裝操作說明,請參閱 JAX 說明文件。