כוונון מבוזר עם Gemma באמצעות Keras

לצפייה ב-ai.google.dev הפעלה ב-Kaggle פתיחה ב-Vertex AI הצגת המקור ב-GitHub

סקירה כללית

ג'מה היא משפחה של מודלים קלילים וחדשניים פתוחים, שנבנו ממחקר וטכנולוגיה ששימשו ליצירת דגמים של Google Gemini. אפשר לשפר את Gemma לצרכים ספציפיים. אבל מודלים גדולים של שפה, כמו Gemma, יכולים להיות גדולים מאוד, וחלק מהם לא מתאימים למאיץ שירה כדי לבצע כוונון עדין. במקרה הזה יש שתי גישות כלליות לכוונון עדין שלהן:

  1. פרמטר יעיל כוונון עדין (PEFT), שמטרתו להקטין את גודל המודל האפקטיבי על ידי הקביעה של דיוק מסוים. LoRA משתייכת לקטגוריה הזו, והמדריך לFine-כוונון Gemma של Keras באמצעות LoRA מדגים כיצד לשפר את מודל Gemma 2B gemma_2b_en עם LoRA באמצעות KerasNLP ב-GPU יחיד.
  2. כוונון עדין של פרמטרים מלאים עם מקביליות של מודל. המקבילה של מודל מחלקת את המשקל של מודל יחיד בין מכשירים שונים ומאפשרת התאמה לעומס (אופקי). מידע נוסף על הדרכה מבוזרת זמין במדריך של Keras.

במדריך הזה תלמדו איך להשתמש ב-Keras עם קצה עורפי של JAX כדי לשפר את מודל Gemma 7B בעזרת LoRA והדרכה מבוזרת של מודל פרליזם ביחידת עיבוד Tensor (TPU) של Google. לתשומת ליבכם: ניתן להשבית את LoRA במדריך הזה, כדי לבצע כוונון של הפרמטר המלא באופן איטי ומדויק יותר.

שימוש במאיצים

טכנית, ניתן להשתמש ב-TPU או ב-GPU במדריך הזה.

הערות על סביבות TPU

ל-Google יש 3 מוצרים שמספקים TPU:

  • Colab מספק את TPU v2, שאינו מספיק למדריך הזה.
  • Kaggle מציע את TPU v3 בחינם, והוא מתאים למדריך הזה.
  • ב-Cloud TPU יש תמיכה ב-TPU v3 ובדורות חדשים יותר. אחת הדרכים להגדיר מעקב המרות היא:
    1. יצירת TPU VM חדש
    2. מגדירים העברה ליציאת SSH עבור יציאת השרת של Jupyter
    3. מתקינים את Jupyter ומפעילים את Jupyter ב-TPU VM, ואז מתחברים ל-Colab באמצעות האפשרות 'התחברות לסביבת זמן ריצה מקומית'.

הערות לגבי הגדרה של מעבדי GPU מרובים

המדריך הזה מתמקד בתרחיש לדוגמה של TPU, אבל תוכלו להתאים אותו בקלות לצרכים שלכם אם יש לכם מכונה עם כמה GPU.

אם אתם מעדיפים לעבוד באמצעות Colab, תוכלו גם להקצות VM עם כמה GPU ל-Colab ישירות דרך האפשרות 'התחברות ל-VM בהתאמה אישית של GCE' בתפריט של Colab Connect.

כאן נתמקד בשימוש ב-TPU החינמי של Kaggle.

לפני שמתחילים

פרטי הכניסה של Kaggle

המודלים של Gemma מתארחים ב-Kaggle. כדי להשתמש ב-Gemma, צריך לבקש גישה ב-Kaggle:

  • כניסה או הרשמה בכתובת kaggle.com
  • פותחים את כרטיס המודל של Gemma ובוחרים באפשרות 'Request Access' (בקשת גישה).
  • ממלאים את טופס ההסכמה ומאשרים את התנאים וההגבלות

לאחר מכן, כדי להשתמש ב-Kaggle API, יוצרים אסימון API:

  • פותחים את הגדרות Kaggle.
  • בוחרים באפשרות "יצירת אסימון חדש".
  • מתבצעת הורדה של קובץ kaggle.json. הוא מכיל את פרטי הכניסה של Kaggle

מריצים את התא הבא ומזינים את פרטי הכניסה של Kaggle כשמוצגת בקשה לעשות זאת.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

דרך חלופית היא להגדיר את KAGGLE_USERNAME ו-KAGGLE_KEY בסביבה שלך, אם הפקודה kagglehub.login() לא מתאימה לך.

התקנה

התקן את Keras ואת KerasNLP עם המודל של Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

הגדרת הקצה העורפי של Keras JAX

ייבא JAX והרצת בדיקת שפיות ב-TPU. ב-Kaggle יש מכשירי TPUv3-8 עם 8 ליבות TPU עם זיכרון בנפח 16GB לכל אחד.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

טעינת המודל

import keras
import keras_nlp

הערות על אימון דיוק מעורב במעבדי NVIDIA GPU

באימון על יחידות GPU של NVIDIA, ניתן להשתמש בדיוק מעורב (keras.mixed_precision.set_global_policy('mixed_bfloat16')) כדי להאיץ את האימון עם השפעה מינימלית על איכות האימון. ברוב המקרים, מומלץ להפעיל את ההגדרה 'דיוק מעורב', כי היא חוסכת גם בזיכרון וגם בזמן. עם זאת, חשוב להיות מודעים לכך שבאצווה קטנה, היא יכולה להגדיל את השימוש בזיכרון פי 1.5 (המשקולות ייטענו פעמיים, ברמת דיוק חצייה ובדיוק מלא).

בכל מקרה, דיוק מוחלט (keras.config.set_floatx("bfloat16")) יפעל ויחסוך בזיכרון, אבל דיוק מעורב לא רלוונטי.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

כדי לטעון את המודל עם המשקולות והמחוונים שמחולקים בין TPUs, קודם צריך ליצור DeviceMesh חדש. DeviceMesh מייצג אוסף של מכשירי חומרה שהוגדרו לחישוב מבוזר, והושק ב-Keras 3 כחלק מממשק ה-API להפצה המאוחד.

ממשק ה-API להפצה מאפשר מקבילה של נתונים ומודלים, וכך מאפשר התאמה לעומס (scaling) של מודלים של למידה עמוקה במספר מאיצים ומארחים. הוא ממנף את ה-framework הבסיסי (למשל JAX) כדי להפיץ את התוכנית ואת ה-tensors בהתאם להוראות הפיצול באמצעות הליך שנקרא 'תוכנית יחידה, הרחבת נתונים (SPMD)'. תוכלו למצוא פרטים נוספים במדריך החדש לממשקי API להפצה של Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

השדה LayoutMap מ-API להפצה מציין איך צריך לפצל או לשכפל את המשקלים ואת רכיבי ה-tensor באמצעות מפתחות המחרוזת, למשל token_embedding/embeddings שבהמשך. ההתייחסות אליהם היא כמו ביטוי רגולרי (regex) כדי להתאים לנתיבי tensor. רכיבי tensor תואמים מופרדים במידות המודל (8 TPUs); אחרים ישוכפלו באופן מלא.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel מאפשר לך לפצל משקולות או גופי הפעלה של מודלים בכל המכשירים ב-DeviceMesh. במקרה הזה, חלק מהמשקולות של דגם Gemma 7B מחולקות בין 8 שבבים של TPU בהתאם ל-layout_map שהוגדר למעלה. עכשיו טוענים את המודל באופן מבוזר.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

כעת מוודאים שהמודל חולק למחיצות (partitioning) בצורה נכונה. ניקח לדוגמה את decoder_block_1.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

מסקנות לפני כוונון

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

המודל יוצר רשימה של סרטי קומדיה מעולים משנות ה-90 לצפייה. עכשיו אנחנו משפרים את המודל של Gemma כדי לשנות את סגנון הפלט.

כוונון עדין עם IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

מבצעים כוונון עדין באמצעות התאמה עם דירוג נמוך (LoRA). שיטת LoRA היא שיטת כוונון, שמפחיתה באופן משמעותי את מספר הפרמטרים שניתן לאמן במשימות downstream, על ידי הקפאת המשקל המלא של המודל והכנסת מספר קטן יותר של משקולות חדשות שניתן לאמן למודל. בעיקרון, כדי לאמן את המטריצות עם המשקל המלא הגדול יותר ב-2 מטריצות קטנות יותר בדירוג נמוך, AxB. השיטה הזו מאפשרת אימון מהיר ויעיל יותר של הזיכרון.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

חשוב לשים לב שהפעלת LoRA מפחיתה באופן משמעותי את מספר הפרמטרים שניתנים לאימון, מ-7 מיליארד ל-11 מיליון בלבד.

מסקנות לאחר כוונון עדין

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

אחרי השיפור, המודל למד את הסגנון של ביקורות על סרטים, ועכשיו הוא מייצר פלט בסגנון הזה בהקשר של סרטי קומדיה משנות ה-90.

המאמרים הבאים

במדריך הזה למדתם איך להשתמש בקצה העורפי של KerasNLP JAX לכוונון עדין של מודל Gemma במערך הנתונים מ-IMDb באופן מבוזר במעבדי ה-TPU העוצמתיים. הנה כמה הצעות לנושאים נוספים שכדאי ללמוד: