הסקת מסקנות מ-CodeGemma באמצעות JAX ו-Flatx

להצגה ב-ai.google.dev הפעלה ב-Google Colab הצגת המקור ב-GitHub

אנחנו מציגים את CodeGemma, אוסף מודלים של קוד פתוח שמבוססים על מודלים של Gemma של Google DeepMind (Gemma Team et al., 2024). CodeGemma היא משפחה של מודלים חד-פעמיים קלילים ופתוחים שנוצרו על ידי אותו מחקר וטכנולוגיה ששימשו ליצירת המודלים של Gemini.

בהמשך למודלים שעברו אימון מראש של Gemma, המודלים של CodeGemma עוברים אימון על יותר מ-500 עד 1,000 מיליארד אסימונים של בעיקר קוד, את אותן הארכיטקטורות כמו של משפחת המודלים של Gemma. כתוצאה מכך, מודלים של CodeGemma משיגים ביצועים ברמה גבוהה של קוד בשני תהליכי ההשלמה. ויצירת משימות חדשות, תוך שמירה על מיומנויות הבנה והסקת מסקנות בקנה מידה נרחב.

ל-CodeGemma יש 3 וריאציות:

  • מודל שעבר אימון מראש על ידי קוד 7B
  • מודל קוד מכוונן להוראה 7B
  • מודל 2B שאומן במיוחד למילוי קוד ויצירה של קוד פתוח.

במדריך הזה מוסבר איך להשתמש במודל CodeGemma עם Flax במשימת השלמת קוד.

הגדרה

1. הגדרת גישת Kaggle ל-CodeGemma

כדי להשלים את המדריך הזה, קודם כול צריך לפעול לפי הוראות ההגדרה שמפורטות במאמר הגדרת Gemma, שמלמדות איך לבצע את הפעולות הבאות:

  • מקבלים גישה ל-CodeGemma בכתובת kaggle.com.
  • צריך לבחור זמן ריצה של Colab עם מספיק משאבים (ל-GPU של T4 אין מספיק זיכרון, צריך להשתמש במקום זאת ב-TPU v2) כדי להריץ את מודל CodeGemma.
  • יצירה והגדרה של שם משתמש ומפתח API ב-Kaggle.

אחרי שמשלימים את ההגדרה של Gemma, עוברים לקטע הבא, שבו מגדירים משתני סביבה לסביבת Colab.

2. הגדרה של משתני סביבה

הגדרה של משתני סביבה בשביל KAGGLE_USERNAME ו-KAGGLE_KEY. כשמוצגת ההודעה 'הענקת גישה?' הודעות, עליך להסכים להעניק גישה סודית.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. התקנת הספרייה gemma

שיפור המהירות באמצעות חומרה של Colab לא מספיק כרגע כדי להריץ את ה-notebook הזה. אם משתמשים ב-Colab Pay As You Go או ב-Colab Pro, צריך ללחוץ על Edit (עריכה) > הגדרות מחברת > בוחרים באפשרות A100 GPU > כדי להפעיל את שיפור המהירות באמצעות חומרה, לוחצים על שמירה.

בשלב הבא צריך להתקין את ספריית gemma של Google DeepMind מ-github.com/google-deepmind/gemma. אם מופיעה שגיאה לגבי 'מקודד יחסי התלות של PIP', בדרך כלל אפשר להתעלם ממנה.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. ייבוא ספריות

ה-notebook הזה משתמש ב-Gemma (שמשתמשת ב-Flax כדי לבנות את שכבות הרשת הנוירונים שלה), וב-SentencePiece (לצורך יצירת אסימונים).

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

טעינת מודל CodeGemma

טוענים את מודל CodeGemma באמצעות הפונקציה kagglehub.model_download, שמקבלת שלושה ארגומנטים:

  • handle: נקודת האחיזה של המודל מ-Kaggle
  • path: (מחרוזת אופציונלית) הנתיב המקומי
  • force_download: (ערך בוליאני אופציונלי) מאלץ הורדה מחדש של המודל
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

בודקים את המיקום של משקולות המודל ושל רכיב ההמרה לאסימונים, ואז מגדירים את משתני הנתיב. ספריית האסימון תהיה בספרייה הראשית שבה הורדתם את המודל, ואילו משקלי המודל יהיו בספריית משנה. לדוגמה:

  • קובץ האסימון של spm.model יהיה בקובץ /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • נקודת הביקורת של המודל תהיה בעוד /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

ביצוע דגימה/הסקה

טוענים ויוצרים את נקודת הביקורת של מודל CodeGemma באמצעות השיטה gemma.params.load_and_format_params:

params = params_lib.load_and_format_params(CKPT_PATH)

טוענים את רכיב ההמרה של CodeGemma, שנוצר באמצעות sentencepiece.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

כדי לטעון באופן אוטומטי את ההגדרה הנכונה מנקודת הביקורת של מודל CodeGemma, משתמשים ב-gemma.transformer.TransformerConfig. הארגומנט cache_size הוא מספר שלבי הזמן במטמון של CodeGemma Transformer. לאחר מכן, יוצרים מודל CodeGemma כ-model_2b עם gemma.transformer.Transformer (שיורש מ-flax.linen.Module).

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

יצירת sampler באמצעות gemma.sampler.Sampler. היא משתמשת בנקודת הביקורת של מודל CodeGemma ובכלי ההמרה לאסימונים.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

אפשר ליצור כמה משתנים כדי לייצג את אסימוני המילוי האמצעי (fim) וליצור כמה פונקציות מסייעות לעיצוב ההנחיה והפלט שנוצר.

לדוגמה, נסתכל על הקוד הבא:

def function(string):
assert function('asdf') == 'fdsa'

אנחנו רוצים למלא את השדה function כדי שהטענת נכוֹנוּת (assertion) תקיים True. במקרה הזה, הקידומת תהיה:

"def function(string):\n"

והסיומת תהיה:

"assert function('asdf') == 'fdsa'"

לאחר מכן אנחנו מעצבים את הטקסט הזה כהנחיה בתור PREFIX- כותבים-MIDDLE (הקטע האמצעי שצריך למלא תמיד מופיע בסוף ההנחיה):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

יוצרים הנחיה ומסיקים מסקנות. צריך לציין את התחילית before את הטקסט ואת הסיומת after, וליצור את ההנחיה המעוצבת באמצעות פונקציית העזרה format_completion prompt.

אפשר לשנות את total_generation_steps (מספר השלבים שבוצעו כשיוצרים תשובה – הדוגמה הזו משתמשת ב-100 כדי לשמר את זיכרון המארח).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

מידע נוסף

  • אפשר לקבל מידע נוסף על ספריית Google DeepMind gemma ב-GitHub, שמכילה מחרוזות docstring של מודולים שהשתמשתם בהם במדריך הזה, כמו gemma.params, gemma.transformer, וגם gemma.sampler.
  • לספריות הבאות יש אתרי תיעוד משלהן: core JAX , Flax ו-Orbax.
  • למידע נוסף על יצירת אסימונים/detokenizer של sentencepiece, כדאי לעיין במאגר GitHub sentencepiece של Google.
  • למסמכי תיעוד של kagglehub, כדאי לעיין ב-README.md במאגר GitHub של kagglehub.
  • איך משתמשים במודלים של Gemma ב-Google Cloud Vertex AI
  • אם משתמשים במעבדי TPU של Google Cloud (מגרסה 3-8 ואילך), חשוב לעדכן גם לחבילת jax[tpu] האחרונה (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html), להפעיל מחדש את סביבת זמן הריצה ולבדוק שהגרסה של jax ושל jaxlib תואמות (!pip list | grep jax). המצב הזה יכול למנוע את שגיאת ה-RuntimeError שעלולה לקרות בגלל חוסר התאמה בין הגרסאות: jaxlib ו-jax. הוראות נוספות להתקנת JAX זמינות במסמכי JAX.