เราขอแนะนำ CodeGemma ซึ่งเป็นคอลเล็กชันโมเดลโค้ดแบบเปิดที่อิงตามโมเดล Gemma ของ Google DeepMind (Gemma Team et al., 2024) CodeGemma เป็นกลุ่มโมเดลแบบเปิดที่ทันสมัยและน้ำหนักเบาซึ่งสร้างขึ้นจากงานวิจัยและเทคโนโลยีเดียวกับที่ใช้สร้างโมเดล Gemini
ต่อจากโมเดล Gemma ที่ผ่านการฝึกล่วงหน้า โมเดล CodeGemma ได้รับการฝึกฝนเพิ่มเติมด้วยโทเค็นโค้ดมากกว่า 500-1,000 พันล้านรายการโดยใช้สถาปัตยกรรมเดียวกับตระกูลโมเดล Gemma ด้วยเหตุนี้ โมเดล CodeGemma จึงมีประสิทธิภาพโค้ดที่ล้ำสมัยทั้งในงานการสร้างและดำเนินการกับโค้ด โดยยังคงรักษาทักษะการเข้าใจและการหาเหตุผลที่ยอดเยี่ยมไว้ได้ในระดับที่กว้าง
CodeGemma มี 3 รูปแบบ ได้แก่
- โมเดลที่ผ่านการฝึกล่วงหน้าด้วยโค้ด 7 พันล้านรายการ
- โมเดลโค้ดที่ปรับแต่งคำสั่ง 7B
- โมเดล 2B ที่ฝึกมาโดยเฉพาะสำหรับการเติมโค้ดและการสร้างแบบปลายเปิด
คู่มือนี้จะแนะนำวิธีใช้โมเดล CodeGemma กับ Flax สำหรับงานเติมโค้ด
ตั้งค่า
1. ตั้งค่าการเข้าถึง Kaggle สําหรับ CodeGemma
หากต้องการทําตามบทแนะนํานี้ให้เสร็จสมบูรณ์ ก่อนอื่นคุณต้องทําตามวิธีการตั้งค่าที่หัวข้อการตั้งค่า Gemma ซึ่งจะแสดงวิธีทําสิ่งต่อไปนี้
- รับสิทธิ์เข้าถึง CodeGemma ใน kaggle.com
- เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอ (GPU T4 มีหน่วยความจําไม่เพียงพอ ให้ใช้ TPU v2 แทน) เพื่อเรียกใช้โมเดล CodeGemma
- สร้างและกําหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle
หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปเพื่อตั้งค่าตัวแปรสภาพแวดล้อมสําหรับสภาพแวดล้อม Colab
2. ตั้งค่าตัวแปรสภาพแวดล้อม
ตั้งค่าตัวแปรสภาพแวดล้อมสําหรับ KAGGLE_USERNAME
และ KAGGLE_KEY
เมื่อได้รับข้อความ "ให้สิทธิ์เข้าถึงไหม" ให้ยอมรับการให้สิทธิ์เข้าถึงข้อมูลลับ
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. ติดตั้งไลบรารี gemma
ปัจจุบันการเร่งฮาร์ดแวร์ของ Colab แบบไม่มีค่าใช้จ่ายinsufficientที่จะเรียกใช้สมุดบันทึกนี้ หากคุณใช้ Colab แบบชําระเงินตามการใช้งานหรือ Colab Pro ให้คลิกแก้ไข > การตั้งค่าโน้ตบุ๊ก > เลือก GPU A100 > บันทึกเพื่อเปิดใช้การเร่งด้วยฮาร์ดแวร์
ขั้นตอนต่อไป คุณต้องติดตั้งไลบรารี gemma
ของ Google DeepMind จาก github.com/google-deepmind/gemma
หากได้รับข้อผิดพลาดเกี่ยวกับ "เครื่องมือแก้ไขข้อกำหนดของ pip" โดยทั่วไปแล้วคุณก็ไม่ต้องสนใจ
pip install -q git+https://github.com/google-deepmind/gemma.git
4. นำเข้าไลบรารี
โน้ตบุ๊กนี้ใช้ Gemma (ซึ่งใช้ Flax เพื่อสร้างเลเยอร์เครือข่ายประสาท) และ SentencePiece (สําหรับการแยกออกเป็นโทเค็น)
import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm
โหลดโมเดล CodeGemma
โหลดโมเดล CodeGemma ด้วย kagglehub.model_download
ซึ่งใช้อาร์กิวเมนต์ 3 รายการ ได้แก่
handle
: แฮนเดิลโมเดลจาก Kagglepath
: (สตริงที่ไม่บังคับ) เส้นทางในเครื่องforce_download
: (บูลีนที่ไม่บังคับ) บังคับให้ดาวน์โหลดโมเดลอีกครั้ง
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
ตรวจสอบตําแหน่งของน้ำหนักโมเดลและตัวแยกวิเคราะห์ แล้วตั้งค่าตัวแปรเส้นทาง ไดเรกทอรีตัวแยกวิเคราะห์จะอยู่ในไดเรกทอรีหลักที่คุณดาวน์โหลดโมเดล ส่วนน้ำหนักโมเดลจะอยู่ในไดเรกทอรีย่อย เช่น
- ไฟล์ตัวแยกวิเคราะห์
spm.model
จะอยู่ใน/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- จุดตรวจของโมเดลจะอยู่ใน
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
ทำการสุ่มตัวอย่าง/อนุมาน
โหลดและจัดรูปแบบจุดตรวจสอบโมเดล CodeGemma ด้วยเมธอด gemma.params.load_and_format_params
params = params_lib.load_and_format_params(CKPT_PATH)
โหลดตัวแยกวิเคราะห์ CodeGemma ซึ่งสร้างขึ้นโดยใช้ sentencepiece.SentencePieceProcessor
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
หากต้องการโหลดการกําหนดค่าที่ถูกต้องจากจุดตรวจสอบโมเดล CodeGemma โดยอัตโนมัติ ให้ใช้ gemma.deprecated.transformer.TransformerConfig
อาร์กิวเมนต์ cache_size
คือจํานวนขั้นตอนเวลาในแคช Transformer
ของ CodeGemma หลังจากนั้น สร้างอินสแตนซ์ของโมเดล CodeGemma เป็น model_2b
ด้วย gemma.deprecated.transformer.Transformer
(ซึ่งรับค่ามาจาก flax.linen.Module
)
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
สร้าง sampler
ด้วย gemma.sampler.Sampler
โดยใช้จุดตรวจสอบโมเดลและตัวแยกวิเคราะห์ของ CodeGemma
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
สร้างตัวแปรเพื่อแสดงโทเค็น "เติมคำในช่องกลาง" (FIM) และสร้างฟังก์ชันตัวช่วยเพื่อจัดรูปแบบพรอมต์และเอาต์พุตที่สร้างขึ้น
ตัวอย่างเช่น ลองดูโค้ดต่อไปนี้
def function(string):
assert function('asdf') == 'fdsa'
เราต้องการกรอก function
เพื่อให้การกล่าวอ้างเป็นจริง True
ในกรณีนี้ คำนำหน้าจะเป็น
"def function(string):\n"
ส่วนส่วนต่อท้ายจะเป็น
"assert function('asdf') == 'fdsa'"
จากนั้นเราจะจัดรูปแบบพรอมต์เป็น PREFIX-SUFFIX-MIDDLE (ส่วนกลางที่ต้องกรอกจะอยู่ท้ายพรอมต์เสมอ) ดังนี้
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
สร้างพรอมต์และทำการอนุมาน ระบุข้อความนำหน้า before
และข้อความต่อท้าย after
แล้วสร้างพรอมต์ที่จัดรูปแบบโดยใช้ฟังก์ชันตัวช่วย format_completion prompt
คุณสามารถปรับแต่ง total_generation_steps
(จํานวนขั้นตอนที่ดำเนินการเมื่อสร้างการตอบกลับ ตัวอย่างนี้ใช้ 100
เพื่อประหยัดหน่วยความจําของโฮสต์)
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
ดูข้อมูลเพิ่มเติม
- ดูข้อมูลเพิ่มเติมเกี่ยวกับไลบรารี
gemma
ของ Google DeepMind ใน GitHub ซึ่งมี docstring ของโมดูลที่คุณใช้ในบทแนะนำนี้ เช่นgemma.params
,gemma.deprecated.transformer
และgemma.sampler
- ไลบรารีต่อไปนี้มีเว็บไซต์เอกสารประกอบของตนเอง ได้แก่ JAX หลัก, Flax และ Orbax
- ดูเอกสารประกอบเกี่ยวกับ
sentencepiece
tokenizer/detokenizer ได้ที่sentencepiece
ที่เก็บ GitHub ของ Google - ดูเอกสารประกอบเกี่ยวกับ
kagglehub
ได้ที่README.md
ในkagglehub
GitHub repo ของ Kaggle - ดูวิธีใช้โมเดล Gemma กับ Google Cloud Vertex AI
- หากคุณใช้ TPU ของ Google Cloud (v3-8 ขึ้นไป) โปรดอัปเดตเป็นแพ็กเกจ
jax[tpu]
เวอร์ชันล่าสุด (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
) รีสตาร์ทรันไทม์ และตรวจสอบว่าเวอร์ชันjax
และjaxlib
ตรงกัน (!pip list | grep jax
) วิธีนี้จะช่วยป้องกันRuntimeError
ที่อาจเกิดขึ้นเนื่องจากเวอร์ชันjaxlib
และjax
ไม่ตรงกัน ดูวิธีการติดตั้ง JAX เพิ่มเติมได้ในเอกสาร JAX