הסקת מסקנות עם Gemma באמצעות JAX ו-Flax

לצפייה ב-ai.google.dev הפעלה ב-Google Colab פתיחה ב-Vertex AI הצגת המקור ב-GitHub

סקירה כללית

ג'מה היא משפחה של מודלי שפה גדולים (LLM) קלילים וחדשניים, המבוססים על המחקר והטכנולוגיה של Google DeepMind Gemini. מדריך זה מדגים איך לבצע דגימה/הסקת מסקנות בסיסיות באמצעות מודל ההוראה Gemma 2B באמצעות ספריית gemma של Google DeepMind שנכתבה באמצעות JAX (ספריית מחשוב מספרית בעלת ביצועים גבוהים), Flax (ספריית רשת נוירונים מבוססת JAX), Orbaxize/Orbax (ספרייה מבוססת JAX לאימון אסימונים{/01) ו-SentencePiece למרות שלא משתמשים ב-Flax ישירות ב-notebook הזה, הוא שימש ליצירת Gemma.

ה-notebook הזה יכול לפעול ב-Google Colab עם T4 GPU בחינם (עוברים אל Edit > Notebook settings > בקטע Hardware Accelerator, בוחרים באפשרות T4 GPU).

הגדרה

‫1. הגדרת גישה ל-Kaggle עבור Gemma

כדי להשלים את המדריך הזה, קודם צריך לבצע את הוראות ההגדרה במאמר הגדרת Gemma, שמדגימות איך לבצע את הפעולות הבאות:

  • ניתן לקבל גישה אל Gemma בכתובת kaggle.com.
  • צריך לבחור סביבת זמן ריצה של Colab עם מספיק משאבים להרצת מודל Gemma.
  • יצירה והגדרה של שם משתמש ומפתח API של Kaggle.

אחרי שתסיימו את ההגדרה של Gemma, עברו לקטע הבא שבו מגדירים משתני סביבה לסביבת Colab.

2. הגדרה של משתני סביבה

הגדרת משתני סביבה ל-KAGGLE_USERNAME ול-KAGGLE_KEY. כאשר מופיעות הודעות השגיאה 'האם להעניק גישה?', מסכימים להעניק גישה סודית.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. התקנת הספרייה gemma

ה-notebook הזה מתמקד בשימוש ב-Colab GPU חינמי. כדי להפעיל את שיפור המהירות באמצעות חומרה, לוחצים על Edit (עריכה) > Notebook settings (הגדרות היומן) > בוחרים T4 GPU > Save.

בשלב הבא, עליך להתקין את ספריית gemma של Google DeepMind מ-github.com/google-deepmind/gemma. אם מופיעה שגיאה לגבי 'מקודד התלות של PIP', בדרך כלל אפשר להתעלם ממנה.

pip install -q git+https://github.com/google-deepmind/gemma.git

טעינה והכנה של מודל Gemma

  1. טוענים את מודל Gemma באמצעות kagglehub.model_download, שמקבל שלושה ארגומנטים:
  • handle: הכינוי של המודל מ-Kaggle
  • path: (מחרוזת אופציונלית) הנתיב המקומי
  • force_download: (אופציונלי בוליאני) אילוץ הורדה מחדש של המודל
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. בודקים את המיקום של משקולות המודל ואת כלי ההמרה לאסימונים, ואז מגדירים את משתני הנתיב. ספריית האסימונים תופיע בספרייה הראשית שבה הורדת את המודל, בעוד שמשקולות המודל יהיו בספריית משנה. למשל:
  • קובץ ה-tokenizer.model יהיה ב-/LOCAL/PATH/TO/gemma/flax/2b-it/2).
  • נקודת הביקורת של המודל תהיה ב-/LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it).
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

בצעו דגימה/הסקת מסקנות

  1. טוענים את נקודת הביקורת של מודל Gemma ומגדירים אותה באמצעות השיטה gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. טוענים את הכלי ליצירת אסימונים של Gemma, שנבנה באמצעות sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. כדי לטעון באופן אוטומטי את ההגדרה הנכונה מנקודת הביקורת של מודל Gemma, משתמשים ב-gemma.transformer.TransformerConfig. הארגומנט cache_size הוא מספר שלבי הזמן במטמון Transformer של Gemma. לאחר מכן, יוצרים את מודל Gemma כ-transformer עם gemma.transformer.Transformer (שעובר בירושה מ-flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. יוצרים sampler עם gemma.sampler.Sampler בנוסף לנקודת הביקורת/המשקולות של מודל Gemma, ולכלי ההמרה לאסימונים:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. כותבים הנחיה ב-input_batch ומבצעים הסקת מסקנות. אפשר לשנות את total_generation_steps (מספר השלבים שבוצעו בזמן יצירת תגובה). בדוגמה הזו נעשה שימוש ב-100 כדי לשמר את זיכרון המארח).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (אופציונלי) אם השלמתם את ה-notebook ואתם רוצים לנסות הנחיה אחרת, מריצים את התא הזה כדי לפנות זיכרון. אחר כך אפשר יהיה ליצור שוב מופע של sampler בשלב 3, ולהתאים אישית את ההנחיה בשלב 4 ולהריץ אותה.
del sampler

מידע נוסף