使用 JAX 和 Flax 透過 CodeGemma 進行推論

我們推出 CodeGemma,這是一組開放原始碼模型,以 Google DeepMind 的 Gemma 模型為基礎 (Gemma Team et al., 2024 年)。CodeGemma 是一系列先進的輕量級開放式模型,採用與建立 Gemini 模型時相同的研究成果和技術。

接續 Gemma 預先訓練模型,CodeGemma 模型會進一步針對 500 至 1000 億個主要程式碼符記進行訓練,並使用與 Gemma 模型系列相同的架構。因此,CodeGemma 模型可在完成和產生任務中達到頂尖程式碼效能,同時維持強大的理解和推理技能。

CodeGemma 有 3 種變化版本:

  • 7B 程式碼預先訓練模型
  • 7B 指令調整的程式碼模型
  • 專門針對程式碼補全和開放式生成作業訓練的 2B 模型。

本指南將逐步說明如何使用 CodeGemma 模型搭配 Flax,執行程式碼補全作業。

設定

1. 為 CodeGemma 設定 Kaggle 存取權

如要完成本教學課程,您必須先按照 Gemma 設定中的說明操作,瞭解如何執行下列操作:

  • 前往 kaggle.com 取得 CodeGemma 存取權。
  • 選取有足夠資源的 Colab 執行階段 (T4 GPU 記憶體不足,請改用 TPU v2) 來執行 CodeGemma 模型。
  • 產生及設定 Kaggle 使用者名稱和 API 金鑰。

完成 Gemma 設定後,請繼續閱讀下一節,瞭解如何設定 Colab 環境的環境變數。

2. 設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEY 的環境變數。當系統提示「授予存取權嗎?」時,請同意提供密鑰存取權。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. 安裝 gemma 程式庫

免費的 Colab 硬體加速功能目前insufficient,無法執行這個筆記本。如果您使用 Colab Pay As You Go 或 Colab Pro,請依序點選「編輯」 >「筆記本設定」 > 選取「A100 GPU」 >「儲存」,啟用硬體加速功能。

接著,您需要從 github.com/google-deepmind/gemma 安裝 Google DeepMind gemma 程式庫。如果您收到有關「pip 的依附元件解析工具」的錯誤訊息,通常可以忽略。

pip install -q git+https://github.com/google-deepmind/gemma.git

4. 匯入程式庫

本筆記本使用 Gemma (使用 Flax 建構神經網路層) 和 SentencePiece (用於標記)。

import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm

載入 CodeGemma 模型

使用 kagglehub.model_download 載入 CodeGemma 模型,該函式需要三個引數:

  • handle:Kaggle 的模型句柄
  • path:(選用字串) 本機路徑
  • force_download:(選用布林值) 強制重新下載模型
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

檢查模型權重和分析器的位置,然後設定路徑變數。分詞器目錄會位於您下載模型的主目錄中,而模型權重會位於子目錄中。例如:

  • spm.model 分詞器檔案會位於 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • 模型查核點會位於 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

執行取樣/推論

使用 gemma.params.load_and_format_params 方法載入及格式化 CodeGemma 模型查核點:

params = params_lib.load_and_format_params(CKPT_PATH)

載入 CodeGemma 分詞器,該分詞器是使用 sentencepiece.SentencePieceProcessor 建構而成:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

如要自動從 CodeGemma 模型檢查點載入正確的設定,請使用 gemma.deprecated.transformer.TransformerConfigcache_size 引數是 CodeGemma Transformer 快取中的時間步數。接著,請使用 gemma.deprecated.transformer.Transformer (繼承自 flax.linen.Module) 將 CodeGemma 模型例項化為 model_2b

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

使用 gemma.sampler.Sampler 建立 sampler。它會使用 CodeGemma 模型查核點和剖析器。

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

建立一些變數來代表填入中間 (FIM) 符記,並建立一些輔助函式來格式化提示和產生的輸出內容。

舉例來說,請看下列程式碼:

def function(string):
assert function('asdf') == 'fdsa'

我們想填入 function,讓斷言保留 True。在這種情況下,前置字串會是:

"def function(string):\n"

後置字串則為:

"assert function('asdf') == 'fdsa'"

然後,我們將這段文字格式化為提示,格式為「前置字串-後置字串-中間」(需要填入內容的中間部分一律會出現在提示結尾):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

建立提示並執行推論。指定前置字串 before 和後置字串 after,並使用輔助函式 format_completion prompt 產生格式化的提示。

您可以調整 total_generation_steps (產生回應時執行的步驟數量,這個範例使用 100 來保留主機記憶體)。

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

瞭解詳情