การอนุมานด้วย CodeGemma โดยใช้ JAX และ Flax

ดูใน ai.google.dev เรียกใช้ใน Google Colab ดูแหล่งที่มาใน GitHub

เรานำเสนอ CodeGemma ซึ่งเป็นคอลเล็กชันของโมเดลโค้ดแบบเปิดซึ่งอิงตามโมเดล Gemma ของ Google DeepMind (Gemma Team et al., 2024) CodeGemma เป็นชุดโมเดลเปิดที่ทันสมัยและน้ำหนักเบา สร้างขึ้นจากการวิจัยและเทคโนโลยีเดียวกันกับที่ใช้ในการสร้างโมเดล Gemini

โมเดล CodeGemma จะได้รับการฝึกเพิ่มเติมจากโมเดลที่ฝึกไว้แล้วของ Gemma ด้วยโทเค็นโค้ดหลักมากกว่า 500 ถึง 1 แสนล้านโทเค็น โดยใช้ สถาปัตยกรรมแบบเดียวกับกลุ่มโมเดล Gemma ด้วยเหตุนี้ โมเดล CodeGemma จึงมีประสิทธิภาพของโค้ดที่ล้ำสมัยในทั้ง 2 โค้ดนี้ และงานการสร้าง ไปพร้อมๆ กับรักษาความแข็งแรง ทักษะการทำความเข้าใจและการให้เหตุผลในวงกว้าง

CodeGemma มี 3 ตัวแปร:

  • โมเดลที่ฝึกด้วยโค้ด 7B ล่วงหน้า
  • โมเดลโค้ดที่มีการปรับแต่งตามคำสั่ง 7B
  • โมเดล 2B ที่ได้รับการฝึกมาโดยเฉพาะสำหรับการใส่ข้อมูลโค้ดและการสร้างแบบเปิดกว้าง

คู่มือนี้จะแนะนำให้คุณทราบเกี่ยวกับการใช้โมเดล CodeGemma ร่วมกับ Flax ในการจัดทำโค้ด

ตั้งค่า

1. ตั้งค่าการเข้าถึง Kaggle สำหรับ CodeGemma

หากต้องการจบบทแนะนำนี้ ก่อนอื่นคุณต้องทำตามวิธีการตั้งค่าที่การตั้งค่า Gemma ซึ่งแสดงวิธีดำเนินการต่อไปนี้

  • รับสิทธิ์เข้าถึง CodeGemma ใน kaggle.com
  • เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอ (GPU T4 มีหน่วยความจำไม่เพียงพอ ใช้ TPU v2 แทน) เพื่อเรียกใช้โมเดล CodeGemma
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปซึ่งจะตั้งค่าตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab

2. ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME และ KAGGLE_KEY เมื่อได้รับข้อความแจ้งผ่านข้อความ "ให้สิทธิ์เข้าถึงไหม" ยอมรับการเข้าถึงข้อมูลลับ

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. ติดตั้งไลบรารี gemma

การเร่งฮาร์ดแวร์ฟรีของ Colab ยังไม่เพียงพอในการเรียกใช้สมุดบันทึกนี้ หากคุณใช้ Colab Pay As You Go หรือ Colab Pro ให้คลิกแก้ไข > การตั้งค่าสมุดบันทึก > เลือก A100 GPU > บันทึกเพื่อเปิดใช้การเร่งฮาร์ดแวร์

ถัดไป คุณต้องติดตั้งไลบรารี Google DeepMind gemma จาก github.com/google-deepmind/gemma หากคุณได้รับข้อผิดพลาดเกี่ยวกับ "รีโซลเวอร์ทรัพยากร Dependency ของ PIP" โดยปกติแล้วคุณไม่ต้องสนใจ

pip install -q git+https://github.com/google-deepmind/gemma.git

4. นำเข้าไลบรารี

สมุดบันทึกนี้ใช้ Gemma (ซึ่งใช้ Flax ในการสร้างเลเยอร์โครงข่ายระบบประสาทเทียม) และ SentencePiece (สำหรับการแปลงข้อมูลเป็นโทเค็น)

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

โหลดโมเดล CodeGemma

โหลดโมเดล CodeGemma ด้วย kagglehub.model_download ซึ่งมีอาร์กิวเมนต์ 3 อย่าง ดังนี้

  • handle: แฮนเดิลโมเดลจาก Kaggle
  • path: (สตริงที่ไม่บังคับ) เส้นทางในเครื่อง
  • force_download: (บูลีนที่ไม่บังคับ) บังคับให้ดาวน์โหลดโมเดลอีกครั้ง
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

ตรวจสอบตำแหน่งของน้ำหนักโมเดลและเครื่องมือแปลงข้อมูลเป็นโทเค็น จากนั้นตั้งค่าตัวแปรเส้นทาง ไดเรกทอรีโทเคนไลซ์จะอยู่ในไดเรกทอรีหลักที่คุณดาวน์โหลดโมเดลไป ขณะที่น้ำหนักโมเดลจะอยู่ในไดเรกทอรีย่อย เช่น

  • ไฟล์ Tokenizer ของ spm.model จะอยู่ใน /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • จุดตรวจสอบโมเดลจะอยู่ใน /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

ทำการสุ่มตัวอย่าง/การอนุมาน

โหลดและจัดรูปแบบจุดตรวจสอบโมเดล CodeGemma ด้วยเมธอด gemma.params.load_and_format_params

params = params_lib.load_and_format_params(CKPT_PATH)

โหลดเครื่องมือแปลงข้อมูลเป็นโทเค็น CodeGemma ซึ่งสร้างขึ้นโดยใช้ sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

หากต้องการโหลดการกําหนดค่าที่ถูกต้องโดยอัตโนมัติจากจุดตรวจสอบโมเดล CodeGemma ให้ใช้ gemma.transformer.TransformerConfig อาร์กิวเมนต์ cache_size คือจำนวนขั้นตอนในแคชของ CodeGemma Transformer หลังจากนั้น ให้สร้างอินสแตนซ์โมเดล CodeGemma เป็น model_2b ด้วย gemma.transformer.Transformer (ซึ่งรับค่าจาก flax.linen.Module)

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

สร้าง sampler ด้วย gemma.sampler.Sampler โดยจะใช้จุดตรวจสอบโมเดล CodeGemma และเครื่องมือแปลงข้อมูลเป็นโทเค็น

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

สร้างตัวแปรบางตัวเพื่อแสดงโทเค็นแบบเติมตรงกลาง (fim) และสร้างฟังก์ชันตัวช่วยบางอย่างเพื่อจัดรูปแบบพรอมต์และเอาต์พุตที่สร้างขึ้น

ลองดูโค้ดต่อไปนี้เป็นตัวอย่าง

def function(string):
assert function('asdf') == 'fdsa'

เราต้องการกรอก function เพื่อให้การยืนยันระงับ True ในกรณีนี้ คำนำหน้าจะเป็น

"def function(string):\n"

และคำต่อท้ายจะเป็น

"assert function('asdf') == 'fdsa'"

จากนั้นเราจะจัดรูปแบบพรอมต์นี้เป็นพรอมต์ PREFIX-SUFFIX-MIDDLE (ส่วนตรงกลางที่ต้องเติมจะแสดงที่ส่วนท้ายของข้อความแจ้งเสมอ):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

สร้างพรอมต์และดำเนินการอนุมาน ระบุข้อความนำ before และข้อความคำต่อท้าย after แล้วสร้างพรอมต์ที่มีการจัดรูปแบบโดยใช้ฟังก์ชันตัวช่วย format_completion prompt

คุณสามารถปรับแต่ง total_generation_steps (จำนวนขั้นตอนที่ดำเนินการเมื่อสร้างคำตอบ โดยตัวอย่างนี้ใช้ 100 เพื่อเก็บรักษาหน่วยความจำของโฮสต์)

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

ดูข้อมูลเพิ่มเติม

  • คุณสามารถดูข้อมูลเพิ่มเติมเกี่ยวกับไลบรารี Google DeepMind gemma ใน GitHub ซึ่งมีเอกสารสตริงของโมดูลที่คุณใช้ในบทแนะนำนี้ เช่น gemma.params gemma.transformer และ gemma.sampler
  • ไลบรารีต่อไปนี้มีเว็บไซต์เอกสารประกอบของตนเอง ได้แก่ JAX หลัก, Flax และ Orbax
  • ดูเอกสารประกอบเกี่ยวกับเครื่องมือแปลงข้อมูลเป็นโทเค็น/เครื่องมือถอดรหัสของ sentencepiece ได้ที่ที่เก็บ sentencepiece GitHub ของ Google
  • ดูเอกสารประกอบเกี่ยวกับ kagglehub ได้ที่ README.md ในที่เก็บ GitHub ของ kagglehub ของ Kaggle
  • ดูวิธีใช้โมเดล Gemma กับ Vertex AI ของ Google Cloud
  • หากคุณใช้ Google Cloud TPU (v3-8 ขึ้นไป) โปรดอัปเดตเป็นแพ็กเกจ jax[tpu] ล่าสุด (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html) รีสตาร์ทรันไทม์ และตรวจสอบว่าเวอร์ชัน jax และ jaxlib ตรงกัน (!pip list | grep jax) ซึ่งจะป้องกัน RuntimeError ที่อาจเกิดขึ้นเนื่องจากเวอร์ชัน jaxlib และ jax ไม่ตรงกัน ดูคำแนะนำในการติดตั้ง JAX เพิ่มเติมในเอกสาร JAX