כוונון מבוזר עם Gemma באמצעות Keras

להצגה ב-ai.google.dev הפעלה ב-Google Colab הפעלה ב-Kaggle פתיחה ב-Vertex AI הצגת המקור ב-GitHub

סקירה כללית

Gemma היא משפחה של מודלים חד-פעמיים קלילים ופתוחים שמבוססים על מחקר וטכנולוגיה ליצירת מודלים של Google Gemini. אפשר לבצע כוונון עדין של Gemma לצרכים ספציפיים. עם זאת, מודלים גדולים של שפה, כמו Gemma, יכולים להיות גדולים מאוד, וחלקם לא יכולים להתאים למאיץ שירה. במקרה הזה יש שתי גישות כלליות לכוונון:

  1. Parameter Efficient Adjust-Tuning (PEFT), שמטרתו להקטין את גודל המודל האפקטיבי על-ידי ויתור על דיוק מסוים. LoRA מסווג בקטגוריה הזו, והמדריך בנושא שיפור מודלים של Gemma ב-Keras באמצעות LoRA מדגים איך לכוונן את מודל Gemma 2B gemma_2b_en עם LoRA באמצעות KerasNLP ב-GPU יחיד.
  2. כוונון מלא של פרמטרים באמצעות מקבילות של מודל. המקבילות של המודל מפיצה את המשקולות של מודל יחיד במספר מכשירים ומאפשרת מדרגיות אופקית. מידע נוסף על הכשרות מבוזרות זמין במדריך הזה של Keeras.

במדריך הזה נסביר איך להשתמש ב-Keras עם קצה עורפי של JAX כדי לשפר את מודל Gemma 7B באמצעות LoRA והאימון המבוזר של המודלים (Tensor Processing Unit) של Google. לתשומת ליבכם: אפשר להשבית את LoRA במדריך הזה כדי לבצע כוונון איטי ומדויק יותר של פרמטר מלא.

שימוש במאיצים

מבחינה טכנית אפשר להשתמש ב-TPU או ב-GPU למדריך הזה.

הערות על סביבות TPU

ל-Google יש 3 מוצרים שמספקים מערכות TPU:

  • Colab הוא כלי שמספק TPU v2 בחינם, וזה מספיק למדריך הזה.
  • Kaggle מציע TPU v3 בחינם והוא גם עובד במדריך הזה.
  • Cloud TPU מציע TPU v3 ודורות חדשים יותר. אחת מהדרכים להגדיר אותו היא:
    1. יצירת TPU VM חדש
    2. מגדירים העברה ליציאת SSH עבור יציאת השרת של Jupyter שרוצים
    3. מתקינים את Jupyter ומפעילים אותו ב-TPU VM, ואז מתחברים ל-Colab דרך 'Connect to a Local runtime'

הערות לגבי הגדרה של ריבוי GPU

על אף שמדריך זה מתמקד בתרחיש לדוגמה של TPU, תוכל להתאים אותו בקלות לצרכים שלך אם יש לך מכונה מרובת GPU.

אם מעדיפים לעבוד באמצעות Colab, אפשר גם להקצות ל-Colab מכונה וירטואלית מרובת GPU באופן ישיר דרך 'התחברות ל-VM בהתאמה אישית ב-GCE' בתפריט של Colab Connect.

כאן נתמקד בשימוש ב-TPU בחינם מ-Kaggle.

לפני שמתחילים

פרטי הכניסה של Kaggle

מודלים של Gemma מתארחים ב-Kaggle. כדי להשתמש ב-Gemma, צריך לבקש גישה ב-Kaggle:

  • נכנסים לחשבון או נרשמים בכתובת kaggle.com.
  • פותחים את הכרטיס של מודל Gemma ובוחרים באפשרות 'בקשת גישה'
  • ממלאים את טופס ההסכמה ומאשרים את התנאים וההגבלות

לאחר מכן, כדי להשתמש ב-Kaggle API, יוצרים אסימון API:

  • פותחים את הגדרות Kaggle
  • בוחרים באפשרות יצירת אסימון חדש.
  • מתבצעת הורדה של קובץ kaggle.json. הוא מכיל את פרטי הכניסה שלך ב-Kaggle

מריצים את התא הבא ומזינים את פרטי הכניסה של Kaggle כשמופיעה בקשה.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

דרך חלופית היא להגדיר את KAGGLE_USERNAME ואת KAGGLE_KEY בסביבה שלך, אם kagglehub.login() לא עובד לך.

התקנה

מתקינים את Keras ו-KerasNLP באמצעות המודל של Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

הגדרת הקצה העורפי של Keras JAX

תייבאו JAX ותריצו בדיקת אבטחה ל-TPU. Kaggle מציעה מכשירי TPUv3-8 עם 8 ליבות TPU עם זיכרון בנפח 16GB לכל אחד.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

טעינת המודל

import keras
import keras_nlp

הערות על אימון דיוק מעורב במעבדי GPU של NVIDIA

באימון על מעבדי NVIDIA, רמת דיוק מעורבת (keras.mixed_precision.set_global_policy('mixed_bfloat16')) יכולה לשמש להאצת האימון תוך השפעה מינימלית על איכות האימון. ברוב המקרים, מומלץ להפעיל דיוק משולב כי זה חוסך גם זיכרון וגם זמן. עם זאת, חשוב לדעת שבאצווה קטנה ההגדרה יכולה להגדיל את השימוש בזיכרון פי 1.5 (משקולות ייטענו פעמיים, בחצי דיוק ובדיוק מלא).

לצורך הסקת מסקנות, חצי דיוק (keras.config.set_floatx("bfloat16")) יעבוד ויחסוך בזיכרון, אבל דיוק משולב לא רלוונטי.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

כדי לטעון את המודל עם המשקולות והטנזורים שמפוזרים בין מעבדי TPU, קודם צריך ליצור DeviceMesh חדש. DeviceMesh מייצג אוסף של מכשירי חומרה שהוגדרו לחישוב מבוזר והושקו ב-Keras 3 כחלק מה-API המאוחד להפצה.

ה-API של ההפצה מאפשר מקבילה של נתונים ומודלים, ומאפשר התאמה לעומס (scaling) של מודלים של למידה עמוקה באופן יעיל בכמה מאיצים ומארחים. הוא משתמש ב-framework הבסיסי (למשל, JAX) כדי להפיץ את התוכנה ואת הפרמטרים tensors בהתאם להנחיות הפיצול (shard) באמצעות תהליך שנקרא תוכנית יחידה, הרחבת נתונים מרובים (SPMD). פרטים נוספים זמינים במדריך החדש של Keras 3 לספק API.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap מממשק ה-API של ההפצה מציין איך צריך לפצל או לשכפל את המשקולות ואת הפרמטרים Tensoring, באמצעות מפתחות המחרוזת, למשל token_embedding/embeddings שלמטה, שיטופלו כמו ביטוי רגולרי (regex) כדי להתאים לנתיבים של tensor. רכיבי Tensor תואמים מחולקים לפי מידות המודל (8 יחידות TPU). יש לשכפל באופן מלא.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

באמצעות ModelParallel אפשר לפצל את משקולות המודל או את רכיבי ההפעלה של Tensors בין כל הסטיות ב-DeviceMesh. במקרה הזה, חלק מהמשקולות של מודל Gemma 7B מחולקות בין 8 שבבי TPU לפי הערך layout_map שהוגדר למעלה. עכשיו טוענים את המודל באופן מבוזר.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

עכשיו מוודאים שהמודל חולק למחיצות בצורה נכונה. ניקח לדוגמה את decoder_block_1.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

הסקת מסקנות לפני כוונון

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

המודל יוצר רשימה של סרטי קומדיה נהדרים משנות ה-90 של המאה ה-20. עכשיו אנחנו משפרים את המודל של Gemma כדי לשנות את סגנון הפלט.

שיפור ההתאמה באמצעות IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

לבצע כוונון באמצעות התאמה עם דירוג נמוך (LoRA). LoRA היא שיטת כוונון שמצמצמת במידה משמעותית את מספר הפרמטרים שניתן לאמן למשימות במורד הזרם, על ידי הקפאת המשקולות המלאות של המודל והכנסת מספר קטן יותר של משקולות חדשות למודל. בעיקרון, LoRA משנה את הפרמטרים של המטריצות הגדולות יותר במשקל מלא באמצעות 2 מטריצות קטנות יותר בעלות דירוג נמוך (AxB).

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

חשוב לזכור: הפעלת LoRA מפחיתה באופן משמעותי את מספר הפרמטרים שניתנים לאימון – מ-7 מיליארד ל-11 מיליון בלבד.

הסקת מסקנות אחרי כוונון עדין

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

לאחר הכוונון, המודל למד את הסגנון של ביקורות על סרטים ועכשיו הוא מפיק פלט בסגנון הזה בהקשר של סרטי קומדיה משנות ה-90 של המאה ה-20.

המאמרים הבאים

במדריך הזה למדתם איך להשתמש בקצה העורפי של KerasNLP JAX כדי לשפר ולחדד מודל Gemma במערך הנתונים מ-IMDb בצורה מבוזרת בעזרת מערכות ה-TPU העוצמתיות. הנה כמה הצעות למידע נוסף שאפשר ללמוד: