כוונון דגמי Gemma ב-Keras באמצעות LoRA

הצגה בכתובת ai.google.dev הרצה ב-Google Colab פתיחה ב-Vertex AI הצגת המקור ב-GitHub

סקירה כללית

Gemma היא משפחה של מודלים פתוחים וקלים לייצור, שמבוססים על אותם מחקר וטכנולוגיה ששימשו ליצירת המודלים של Gemini.

מודלים גדולים של שפה (LLM) כמו Gemma הוכיחו את היעילות שלהם במגוון משימות של עיבוד שפה טבעית (NLP). מודל שפה גדול (LLM) מאומן מראש על מאגר גדול של טקסט באופן עצמאי. אימון מראש עוזר למודלים מסוג LLM ללמוד ידע לשימוש כללי, כמו קשרים סטטיסטיים בין מילים. לאחר מכן אפשר לבצע כוונון עדין של LLM באמצעות נתונים ספציפיים לדומיין כדי לבצע משימות במורד הזרם (כמו ניתוח סנטימנטים).

מודלים גדולים של שפה הם גדולים מאוד (פרמטרים בסדר גודל של מיליארדים). ברוב האפליקציות לא נדרש כוונון מלא (מעדכן את כל הפרמטרים במודל) כי מערכי נתונים אופייניים של כוונון עדין הם הרבה יותר קטנים יחסית ממערכי הנתונים של האימון שלפני האימון.

התאמת דירוג נמוך (LoRA) היא שיטת כוונון עדינה שמפחיתה באופן משמעותי את מספר הפרמטרים שניתן לאמן למשימות downstream, על ידי הקפאת המשקולות של המודל והכנסת מספר קטן יותר של משקולות חדשות למודל. כך האימון באמצעות LoRA מהיר יותר ויעיל יותר מבחינת שימוש בזיכרון, והוא מייצר משקלים קטנים יותר של מודלים (כמה מאות מגה-בייט), תוך שמירה על איכות הפלט של המודלים.

במדריך הזה תלמדו איך להשתמש ב-KerasNLP כדי לבצע כוונון מדויק של LoRA במודל Gemma 2B באמצעות מערך הנתונים של Databricks Dolly 15k. מערך הנתונים הזה מכיל 15,000 צמדים של הנחיות / תשובות שנוצרו על ידי אדם באיכות גבוהה, שמיועדים במיוחד לכוונון עדין של מודלים גדולים של שפה (LLM).

הגדרה

קבלת גישה ל-Gemma

כדי להשלים את המדריך הזה, קודם צריך להשלים את הוראות ההגדרה במאמר הגדרת Gemma. בהוראות להגדרת Gemma מוסבר איך לבצע את הפעולות הבאות:

  • ניתן לקבל גישה ל-Gemma בכתובת kaggle.com.
  • בוחרים זמן ריצה של Colab עם מספיק משאבים כדי להריץ את המודל Gemma 2B.
  • יצירת שם משתמש ומפתח API ב-Kaggle והגדרתם.

אחרי שתסיימו את הגדרת Gemma, תוכלו לעבור לקטע הבא שבו תגדירו משתני סביבה לסביבת Colab.

בחירת סביבת זמן הריצה

כדי להשלים את המדריך הזה, צריך שיהיה לכם זמן ריצה של Colab עם מספיק משאבים כדי להריץ את מודל Gemma. במקרה הזה, תוכלו להשתמש ב-GPU T4:

  1. בפינה השמאלית העליונה של חלון Colab, בוחרים באפשרות ▾ (אפשרויות חיבור נוספות).
  2. בוחרים באפשרות Change runtime type (שינוי הסוג של סביבת זמן הריצה).
  3. בקטע Hardware accelerator (שיפור המהירות באמצעות חומרה), בוחרים באפשרות T4 GPU.

הגדרת מפתח ה-API

כדי להשתמש ב-Gemma, צריך לספק את שם המשתמש ב-Kaggle ומפתח API של Kaggle.

כדי ליצור מפתח API של Kaggle, עוברים לכרטיסייה Account (חשבון) בפרופיל המשתמש ב-Kaggle ובוחרים באפשרות Create New Token (יצירת אסימון חדש). הפעולה הזו תגרום להורדה של קובץ kaggle.json שמכיל את פרטי הכניסה ל-API.

ב-Colab, בוחרים באפשרות Secrets (🔑) בחלונית הימנית ומוסיפים את שם המשתמש של Kaggle ואת מפתח ה-API של Kaggle. שומרים את שם המשתמש בשם KAGGLE_USERNAME ואת מפתח ה-API בשם KAGGLE_KEY.

הגדרה של משתני סביבה

הגדרה של משתני סביבה בשביל KAGGLE_USERNAME ו-KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

התקנת יחסי תלות

מתקינים את Keras,‏ KerasNLP ויחסי תלות אחרים.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

בחירת קצה עורפי

Keras הוא ממשק API ברמה גבוהה ללמידת עומק במסגרות מרובות, שתוכנן להיות פשוט וקל לשימוש. באמצעות Keras 3, אפשר להריץ תהליכי עבודה באחד משלושת הקצוות העורפיים: TensorFlow,‏ JAX או PyTorch.

במדריך הזה, מגדירים את הקצה העורפי של JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

ייבוא חבילות

מייבאים את Keras ואת KerasNLP.

import keras
import keras_nlp

טעינת מערך נתונים

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

עיבוד מראש של הנתונים. במדריך הזה נעשה שימוש בקבוצת משנה של 1,000 דוגמאות אימון כדי להריץ את המחברות מהר יותר. כדאי להשתמש בנתוני אימון נוספים כדי לשפר את איכות השיפור.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

טעינת מודל

ב-KerasNLP יש הטמעות של ארכיטקטורות מודלים פופולריות. במדריך הזה תלמדו ליצור מודל באמצעות GemmaCausalLM, מודל Gemma מקצה לקצה ליצירת מודלים של שפה גורמית. מודל שפה סיבתי חוזה את האסימון הבא על סמך אסימונים קודמים.

יוצרים את המודל באמצעות השיטה from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

השיטה from_preset יוצרת מודל מארכיטקטורה וממשקלים מוגדרים מראש. בקוד שלמעלה, המחרוזת 'gemma2_2b_en' מציינת את הארכיטקטורה המוגדרת מראש – מודל Gemma עם 2 מיליארד פרמטרים.

הסקת מסקנות לפני כוונון עדין

בקטע הזה תשלחו שאילתה למודל עם הנחיות שונות כדי לראות איך הוא מגיב.

הנחיה ליצירת נסיעה באירופה

שולחים שאילתה למודל כדי לקבל הצעות למה לעשות בנסיעה לאירופה.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

המודל יגיב עם טיפים כלליים לתכנון נסיעה.

הנחיה בנושא פוטוסינתזה ל-ELI5

מבקשים מהמודל להסביר את התהליך של הפוטוסינתזה במונחים פשוטים שילד בן 5 יוכל להבין.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

התשובה של המודל מכילה מילים שיכול להיות שיהיה קשה לילדים להבין, כמו 'כלורופיל'.

כוונון עדין של LoRA

כדי לקבל תשובות טובות יותר מהמודל, אפשר לשפר את המודל באמצעות התאמה של דירוג נמוך (LoRA) באמצעות מערך הנתונים Dolly 15k של Databricks.

דירוג LoRA קובע את המידות של המטריצות הניתנות לאימון שנוספו למשקולות המקוריות של ה-LLM. הוא שולט ברמת ההבעה והדיוק של התאמות הכוונון.

כשהדירוג גבוה יותר, אפשר לבצע שינויים מפורטים יותר, אבל המשמעות היא גם יותר פרמטרים שניתן לאמן. דירוג נמוך יותר מפחית את עומס העיבוד, אבל יכול להיות שההתאמה תהיה פחות מדויקת.

המדריך הזה משתמש בדירוג LoRA של 4. בפועל, כדאי להתחיל עם דירוג קטן יחסית (כמו 4, 8, 16). זוהי שיטה יעילה מבחינה חישובית לניסויים. מארגנים את המודל לפי הדירוג הזה ומעריכים את שיפור הביצועים במשימה. אפשר להעלות את הדירוג בהדרגה בניסיונות הבאים ולבדוק אם זה משפר את הביצועים עוד יותר.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

חשוב לזכור שהפעלת LoRA מפחיתה באופן משמעותי את מספר הפרמטרים שאפשר לאמן (מ-2.6 מיליארד ל-2.9 מיליון).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

הערה לגבי כוונון עדין של דיוק מעורב ב-GPU של NVIDIA

מומלץ להשתמש ברמת דיוק מלאה לצורך כוונון מדויק. כשמבצעים כוונון עדין ב-GPU של NVIDIA, חשוב לזכור שאפשר להשתמש ברמת דיוק משולבת (keras.mixed_precision.set_global_policy('mixed_bfloat16')) כדי לזרז את האימון תוך השפעה מינימלית על איכות האימון. התאמה אישית ברמת דיוק משולבת צורכת יותר זיכרון, ולכן היא שימושית רק ב-GPU גדולים יותר.

להסקה, דיוק חצי (keras.config.set_floatx("bfloat16")) יפעל ויחסוך זיכרון, בעוד שדיוק מעורב לא רלוונטי.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

הסקת מסקנות לאחר שינוי מדויק

אחרי שמשלימים את השיפורים, התשובות יתקבלו בהתאם להוראות שמופיעות בהנחיה.

הודעה לנסיעה לאירופה

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

עכשיו המודל ממליץ על מקומות שכדאי לבקר בהם באירופה.

הנחיה בנושא פוטוסינטזה ל-ELI5

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

המודל עכשיו מסביר פוטוסינתזה בשפה פשוטה יותר.

שימו לב שלמטרות הדגמה, המדריך הזה מכוונן את המודל בקבוצת משנה קטנה של מערך הנתונים לתקופת זמן אחת בלבד עם ערך דירוג LoRA נמוך. כדי לקבל תשובות טובות יותר מהמודל המכוונן, אפשר להתנסות באפשרויות הבאות:

  1. הגדלת הגודל של מערך הנתונים לכוונון מדויק
  2. הדרכה לשלבים נוספים (תקופות)
  3. הגדרת דירוג LoRA גבוה יותר
  4. שינוי ערכי ההיפר-פרמטרים, כמו learning_rate ו-weight_decay.

הסיכום והשלבים הבאים

המדריך הזה עוסק בכוונון עדין של LoRA במודל Gemma באמצעות KerasNLP. לאחר מכן, כדאי לעיין במסמכים הבאים: