גמה ב-PyTorch

הצגה בכתובת ai.google.dev הרצה ב-Google Colab הצגת המקור ב-GitHub

זוהי הדגמה מהירה של הרצת מסקנות Gemma ב-PyTorch. לפרטים נוספים, תוכלו לעיין כאן במאגר GitHub של ההטמעה הרשמית של PyTorch.

שימו לב:

  • זמן הריצה של Python ל-CPU ב-Colab בחינם וזמן הריצה של Python ל-GPU T4 מספיקים להרצת המודלים של Gemma 2B ומודלים של 7B int8 quantized.
  • אם אתם משתמשים בתרחישים מתקדמים לשימוש במעבדי GPU או TPU אחרים, צריך לעיין ב-README.md במאגר הרשמי.

1. הגדרת גישה ל-Kaggle עבור Gemma

כדי להשלים את המדריך הזה, קודם צריך לפעול לפי הוראות ההגדרה במאמר הגדרת Gemma, שמוסבר בהן איך לבצע את הפעולות הבאות:

  • ניתן לקבל גישה ל-Gemma בכתובת kaggle.com.
  • בוחרים סביבת זמן ריצה ב-Colab עם מספיק משאבים להרצת מודל Gemma.
  • יצירת שם משתמש ומפתח API ב-Kaggle והגדרתם.

אחרי שתסיימו את הגדרת Gemma, תוכלו לעבור לקטע הבא שבו תגדירו משתני סביבה לסביבת Colab.

2. הגדרה של משתני סביבה

מגדירים את משתני הסביבה KAGGLE_USERNAME ו-KAGGLE_KEY. כשמוצגת ההודעה "לתת גישה?", מאשרים לתת גישה סודית.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

התקנת יחסי תלות

pip install -q -U torch immutabledict sentencepiece

הורדת משקולות המודל

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

הורדת הטמעת המודל

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

הגדרת המודל

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

הרצת היסק

לפניכם דוגמאות ליצירה במצב צ'אט וליצירה עם מספר בקשות.

מודלי Gemma שהותאמו להוראות הוכשרו באמצעות פורמט ספציפי שמוסיף הערות למידע נוסף לדוגמאות של התאמת הוראות, גם במהלך האימון וגם במהלך ההסקה. ההערות (1) מציינות את התפקידים בשיחה, וגם (2) מסמנות את התורות בשיחה.

אסימוני ההערות הרלוונטיים הם:

  • user: תור המשתמש
  • model: סיבוב המודל
  • <start_of_turn>: תחילת תורו של הדובר
  • <end_of_turn><eos>: סוף תיבת הדו-שיח

מידע נוסף על עיצוב הנחיות למודלים של Gemma שכווננו לפי הוראות זמין כאן.

קטע הקוד הבא מדגים איך לעצב הנחיה למודל Gemma שהותאמה להוראות באמצעות תבניות צ'אט של משתמשים ומודלים בשיחה עם כמה תורנויות.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

מידע נוסף

עכשיו, אחרי שלמדתם איך להשתמש ב-Gemma ב-Pytorch, תוכלו לבדוק את הדברים הרבים האחרים ש-Gemma יכולה לעשות בכתובת ai.google.dev/gemma. כדאי לעיין גם במקורות המידע שקשורים לנושא: