הסקת מסקנות בעזרת RecurrentGemma באמצעות JAX ו-Flatx

להצגה ב-ai.google.dev הפעלה ב-Google Colab פתיחה ב-Vertex AI הצגת המקור ב-GitHub

המדריך הזה מדגים איך לבצע דגימה/הסקה בסיסית באמצעות מודל RecurrentGemma 2B Instruct באמצעות ספריית recurrentgemma של DeepMind שנכתבה באמצעות JAX (ספריית מחשוב עתירת ביצועים), Flax (ספריית רשת נוירונים מבוססת JAX), Orbax וספרייה מבוססת JAX (ספרייה מבוססת JAX).SentencePiece ב-notebook הזה לא נעשה שימוש ישירות ב-Flaser, אבל נעשה שימוש ב-Flatx כדי ליצור את Gemma ו-RecurrentGemma (מודל Griffin).

אפשר להריץ את ה-notebook הזה ב-Google Colab עם GPU מסוג T4 (עוברים אל Edit (עריכה) > Notebook settings (הגדרות מחברת) > בקטע Hardware Aacclerator, בוחרים באפשרות T4 GPU).

הגדרה

בקטעים הבאים מוסבר השלבים להכנת notebook לשימוש במודל RecurrentGemma, כולל גישה למודל, קבלת מפתח API והגדרת זמן הריצה של ה-notebook.

הגדרת גישה של Kaggle ל-Gemma

כדי להשלים את המדריך הזה, קודם כול צריך לפעול לפי הוראות ההגדרה בדומה להגדרה של Gemma, למעט כמה יוצאים מן הכלל:

  • מקבלים גישה אל RecurrentGemma (במקום Gemma) בכתובת kaggle.com.
  • צריך לבחור זמן ריצה של Colab עם מספיק משאבים כדי להריץ את המודל RecurrentGemma.
  • יצירה והגדרה של שם משתמש ומפתח API ב-Kaggle.

אחרי שמשלימים את ההגדרה של RecurrentGemma, עוברים לקטע הבא, שבו מגדירים משתני סביבה לסביבת Colab.

הגדרה של משתני סביבה

הגדרה של משתני סביבה בשביל KAGGLE_USERNAME ו-KAGGLE_KEY. כשמוצגת ההודעה 'הענקת גישה?' הודעות, הסכמה למתן גישה סודית.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

התקנת הספרייה recurrentgemma

ה-notebook הזה מתמקד בשימוש ב-GPU של Colab בחינם. כדי להפעיל שיפור מהירות באמצעות חומרה, לוחצים על עריכה > הגדרות מחברת > בוחרים באפשרות T4 GPU > לוחצים על שמירה.

בשלב הבא צריך להתקין את ספריית recurrentgemma של Google DeepMind מ-github.com/google-deepmind/recurrentgemma. אם מופיעה שגיאה לגבי 'מקודד יחסי התלות של PIP', בדרך כלל אפשר להתעלם ממנה.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

טוענים ומכינים את מודל RecurrentGemma

  1. טוענים את המודל RecurrentGemma באמצעות kagglehub.model_download, שמקבלת שלושה ארגומנטים:
  • handle: נקודת האחיזה של המודל מ-Kaggle
  • path: (מחרוזת אופציונלית) הנתיב המקומי
  • force_download: (ערך בוליאני אופציונלי) מאלץ הורדה מחדש של המודל
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. בודקים את המיקום של משקולות המודל ושל רכיב ההמרה לאסימונים, ואז מגדירים את משתני הנתיב. ספריית האסימון תהיה בספרייה הראשית שבה הורדתם את המודל, ואילו משקלי המודל יהיו בספריית משנה. לדוגמה:
  • הקובץ tokenizer.model יהיה בתיקייה /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • נקודת הביקורת של המודל תהיה ב-/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

ביצוע דגימה/הסקה

  1. טוענים את נקודת הביקורת של מודל RecurrentGemma באמצעות ה-method recurrentgemma.jax.load_parameters. הארגומנט sharding שמוגדר לערך "single_device" טוען את כל הפרמטרים של המודל במכשיר אחד.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. טוענים את כלי ההמרה לאסימונים של מודל RecurrentGemma, שנוצר באמצעות sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. כדי לטעון באופן אוטומטי את ההגדרה הנכונה מנקודת הביקורת של המודל RecurrentGemma, משתמשים ב-recurrentgemma.GriffinConfig.from_flax_params_or_variables. לאחר מכן, יוצרים מופע של המודל Griffin באמצעות recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. יוצרים sampler עם recurrentgemma.jax.Sampler מעל נקודת הביקורת/המשקולות של מודל RecurrentGemma:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. כותבים הנחיה ב-prompt ומבצעים הסקת מסקנות. אפשר לשנות את total_generation_steps (מספר השלבים שבוצעו כשיוצרים תשובה – הדוגמה הזו משתמשת ב-50 כדי לשמר את זיכרון המארח).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

מידע נוסף