前往 ai.google.dev 查看 | 在 Google Colab 中執行 | 前往 GitHub 查看原始碼 |
發表 CodeGemma,這是一組以 Google DeepMind 的 Gemma 模型為基礎 (Gemma Team 等人,2024 年)。 CodeGemma 是一系列先進的開放式模型,與建立 Gemini 模型所使用的研究和技術相同。
延續 Gemma 預先訓練模型,CodeGemma 模型進一步使用 500 至 1,0000 億個主要程式碼的權杖訓練, 和 Gemma 模型系列採用的架構相同因此,CodeGemma 模型在完成兩個任務時都能提供最先進的程式碼效能 同時確保 大規模理解及推理技能
CodeGemma 有 3 個變化版本:
- 70 億個程式碼的預先訓練模型
- 70 億個用於指令調整的程式碼模型
- 2B 模型,專門訓練用於程式碼填充和開放式生成技術。
本指南將逐步引導您搭配 Flax 使用 CodeGemma 模型,完成程式碼完成工作。
設定
1. 設定 CodeGemma 的 Kaggle 存取權
為完成本教學課程,您必須先按照 Gemma 設定中的設定說明操作,其中將說明如何執行下列操作:
- 立即前往 kaggle.com 存取 CodeGemma。
- 請選取資源充足的 Colab 執行階段 (T4 GPU 記憶體不足,改用 TPU v2) 來執行 CodeGemma 模型。
- 產生並設定 Kaggle 使用者名稱和 API 金鑰。
完成 Gemma 設定後,請繼續前往下一節,設定 Colab 環境的環境變數。
2. 設定環境變數
設定 KAGGLE_USERNAME
和 KAGGLE_KEY
的環境變數。系統顯示「授予存取權?」提示訊息時,訊息,同意提供密鑰存取權。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. 安裝 gemma
程式庫
目前免費的 Colab 硬體加速功能不足,無法執行這個筆記本。如果使用 Colab Pay As You Go 或 Colab Pro,請按一下「編輯」>筆記本設定 >依序選取「A100 GPU」>按一下「儲存」即可啟用硬體加速功能。
接下來,您需要從 github.com/google-deepmind/gemma
安裝 Google DeepMind gemma
程式庫。如果收到有關「pip 依附元件解析器」的錯誤,通常可以忽略。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. 匯入程式庫
這個筆記本使用 Gemma (使用 Flax 建構類神經網路層) 和 SentencePiece (用於權杖化)。
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
載入 CodeGemma 模型
使用 kagglehub.model_download
載入 CodeGemma 模型,該模型會使用三個引數:
handle
:Kaggle 的模型控制代碼path
:(選用字串) 本機路徑force_download
:(選用布林值) 強制重新下載模型
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
檢查模型權重的位置和符記化工具,然後設定路徑變數。符記化工具目錄會存放您下載模型的主要目錄,而模型權重則位於子目錄。例如:
spm.model
符記化工具檔案將位於/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
中- 模型查核點位於
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
執行取樣/推論
使用 gemma.params.load_and_format_params
方法載入 CodeGemma 模型查核點並設定格式:
params = params_lib.load_and_format_params(CKPT_PATH)
載入使用 sentencepiece.SentencePieceProcessor
建構的 CodeGemma 符記化工具:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
如要從 CodeGemma 模型查核點自動載入正確的設定,請使用 gemma.transformer.TransformerConfig
。cache_size
引數是指 CodeGemma Transformer
快取中的時間長度步驟數。之後,請使用 gemma.transformer.Transformer
(繼承自 flax.linen.Module
) 將 CodeGemma 模型例項化為 model_2b
。
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
使用 gemma.sampler.Sampler
建立 sampler
。並會使用 CodeGemma 模型查核點和符記化工具。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
建立一些變數來代表中間 (fim) 符記,並建立一些輔助函式,以設定提示格式和產生的輸出內容。
以下列程式碼為例:
def function(string):
assert function('asdf') == 'fdsa'
我們要填入 function
,讓斷言保留 True
。在此情況下,前置字元會是:
"def function(string):\n"
後置字串會是:
"assert function('asdf') == 'fdsa'"
接著,我們會將此格式的格式設為 PREFIX-SUFFIX-MIDDLE (需填入的中間部分一律在提示結尾):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
建立提示並執行推論。指定前置字串 before
文字和後置字串 after
文字,然後使用 format_completion prompt
輔助函式產生格式化提示。
您可以調整 total_generation_steps
(產生回應時執行的步驟數,此範例使用 100
來保留主機記憶體)。
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
瞭解詳情
- 您可以進一步瞭解 GitHub 上的 Google DeepMind 程式庫
gemma
,其中包含您在本教學課程中使用的模組 docstring,例如gemma.params
、gemma.transformer
和gemma.sampler
。 - 下列程式庫都有專屬的說明文件網站:核心 JAX、Flax 以及 Orbax。
- 如需
sentencepiece
權杖化工具/解碼器說明文件,請前往 Google 的sentencepiece
GitHub 存放區。 - 如需
kagglehub
說明文件,請查看 Kagglekagglehub
GitHub 存放區中的README.md
。 - 瞭解如何搭配使用 Gemma 模型與 Google Cloud Vertex AI。
- 如果您使用的是 Google Cloud TPU (v3-8 及以上版本),請務必一併更新至最新的
jax[tpu]
套件 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
)、重新啟動執行階段,並檢查jax
和jaxlib
版本是否相符 (!pip list | grep jax
)。這麼做可以避免因jaxlib
和jax
版本不符而可能發生的RuntimeError
。如需更多 JAX 安裝操作說明,請參閱 JAX 文件。