ضبط موزّع بالتعاون مع "جيما" باستخدام Keras

العرض على ai.google.dev التنفيذ في Google Colab التنفيذ في Kaggle الفتح في Vertex AI عرض المصدر على GitHub

نظرة عامة

"جيما" هي مجموعة من النماذج المتطوّرة والخفيفة المتاحة للجميع، والتي تم إنشاؤها استنادًا إلى الأبحاث والتكنولوجيا المستخدَمة لإنشاء نماذج Google Gemini. يمكن تحسين Gemma بشكل أكبر لتلبية احتياجات معيّنة. ولكن يمكن أن تكون النماذج اللغوية الكبيرة، مثل Gemma، كبيرة جدًا في الحجم وقد لا يناسب بعضها استخدام مسرع واحد لإجراء التحسينات. في هذه الحالة، هناك نهجان عامان لضبطهما:

  1. ميزة "التحسين الدقيق والفعّال للمَعلمات" (PEFT)، التي تهدف إلى تصغير حجم النموذج الفعّال من خلال التضحية ببعض الدقّة يندرج LoRA ضمن هذه الفئة، ويوضّح الدليل التعليمي تحسين نماذج Gemma في Keras باستخدام LoRA كيفية تحسين نموذج Gemma 2B gemma_2b_en باستخدام LoRA باستخدام KerasNLP على وحدة معالجة رسومات واحدة.
  2. تحسين كامل للمَعلمات باستخدام التوازي في النماذج يوزّع "التوازي في النماذج" أوزان نموذج واحد على أجهزة متعددة ويفعّل التوسّع الأفقي. يمكنك الاطّلاع على مزيد من المعلومات عن التدريب الموزّع في دليل Keras هذا.

يرشدك هذا الدليل التعليمي إلى استخدام Keras مع الخلفية JAX لتحسين نموذج Gemma 7B باستخدام LoRA والتدريب الموزّع على مستوى النماذج المشابهة على وحدة معالجة Tensor (TPU) من Google. لاحظ أنه يمكن إيقاف LoRA في هذا البرنامج التعليمي لضبط المعلمة الكاملة بشكل أبطأ وأكثر دقة.

استخدام المسرّعات

من الناحية الفنية، يمكنك استخدام وحدة معالجة الموتّرات أو وحدة معالجة الرسومات في هذا البرنامج التعليمي.

ملاحظات حول بيئات وحدات معالجة الموتّرات

لدى Google 3 منتجات توفّر وحدات معالجة TPU:

  • يوفّر Colab الإصدار الثاني من TPU مجانًا، وهو ما يكفي لهذا الدليل التوجيهي.
  • تقدّم Kaggle الإصدار 3 من وحدة معالجة الموتّرات TPU مجانًا، ويمكن استخدامها أيضًا في هذا الدليل التعليمي.
  • تتوفّر في Cloud TPU وحدة معالجة الموتّرات (TPU) الإصدار 3 والأجيال الأحدث. في ما يلي إحدى طرق إعداده:
    1. إنشاء جهاز افتراضي TPU جديد
    2. إعداد إعادة توجيه منفذ SSH لخادم Jupyter المقصود
    3. ثبِّت Jupyter وابدأ تشغيله على جهاز TPU الافتراضي، ثم اتصل بـ Colab من خلال "الاتصال ببيئة تشغيل على الجهاز".

ملاحظات حول إعداد وحدات معالجة الرسومات المتعدّدة

على الرغم من أنّ هذا الدليل التعليمي يركّز على حالة استخدام وحدات معالجة النطاق الفائق (TPU)، يمكنك بسهولة تكييفه لتلبية احتياجاتك إذا كان لديك جهاز مزوّد بعدة وحدات معالجة رسومات.

إذا كنت تفضّل العمل من خلال Colab، يمكنك أيضًا توفير جهاز افتراضي مزوّد بعدة وحدات معالجة رسومات لتطبيق Colab مباشرةً من خلال "الاتصال بجهاز Google Compute Engine افتراضي مخصّص" في قائمة Colab Connect.

سنركّز على استخدام وحدة معالجة الموتّرات المجانية من Kaggle هنا.

قبل البدء

بيانات اعتماد Kaggle

تستضيف منصة Kaggle نماذج Gemma. لاستخدام Gemma، يُرجى طلب إذن الوصول على Kaggle:

  • سجِّل الدخول أو سجِّل على kaggle.com.
  • افتح بطاقة طراز Gemma واختَر "طلب الوصول".
  • إكمال نموذج الموافقة وقبول الأحكام والشروط

بعد ذلك، لاستخدام Kaggle API، أنشئ رمزًا مميّزًا لواجهة برمجة التطبيقات:

  • افتح إعدادات Kaggle
  • اختَر "إنشاء رمز مميّز جديد".
  • يتم تنزيل ملف kaggle.json. يحتوي على بيانات اعتمادك في Kaggle

شغِّل الخلية التالية وأدخِل بيانات اعتمادك على Kaggle عندما يُطلب منك ذلك.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

هناك طريقة بديلة وهي ضبط KAGGLE_USERNAME وKAGGLE_KEY في بيئتك إذا لم تعمل لك دالة kagglehub.login()‎.

تثبيت

ثبِّت Keras وKerasNLP باستخدام نموذج Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

إعداد الخلفية في Keras JAX

استورِد حزمة JAX ونفِّذ عملية التحقّق من الصحة على TPU. تقدم Kaggle أجهزة TPUv3-8 التي تحتوي على 8 نوى TPU مع ذاكرة تبلغ 16 غيغابايت لكل منها.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

تحميل النموذج

import keras
import keras_nlp

ملاحظات حول التدريب باستخدام الدقة المختلطة على وحدات معالجة الرسومات من NVIDIA

عند التدريب على وحدات معالجة الرسومات NVIDIA، يمكن استخدام الدقة المختلطة (keras.mixed_precision.set_global_policy('mixed_bfloat16')) لتسريع التدريب بأقل تأثير ممكن في جودة التدريب. في معظم الحالات، يُنصح بتفعيل الدقة المختلطة لأنّها توفّر كلّ من الذاكرة والوقت. ومع ذلك، يُرجى العِلم أنّه عند استخدام أحجام صغيرة للمجموعات، يمكن أن يؤدي ذلك إلى زيادة استخدام الذاكرة بمقدار 1.5 مرة (سيتم تحميل الأوزان مرتين، بنصف الدقة والدقة الكاملة).

بالنسبة إلى الاستنتاج، سيعمل النصف الدقيق (keras.config.set_floatx("bfloat16")) ويحفظ الذاكرة بينما لا يكون الدقّة المختلطة قابلة للتطبيق.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

لتحميل النموذج باستخدام الأوزان والعناصر المصفوفة الموزّعة على وحدات TPU، عليك أولاً إنشاء DeviceMesh جديد. يمثّل DeviceMesh مجموعة من الأجهزة التي تم ضبطها للحساب الموزّع، وقد تم تقديمها في Keras 3 كجزء من واجهة برمجة التطبيقات الموحّدة للتوزيع.

تتيح واجهة برمجة التطبيقات للتوزيع إمكانية موازاة البيانات والنماذج، ما يسمح بتوسيع نطاق نماذج التعلم المتعمّق بفعالية على برامج مسرِّعات ومضيفات متعددة. ويستفيد من الإطار الأساسي (مثل JAX) لتوزيع البرنامج والعناصر المصفوفة وفقًا لتوجيهات التجزئة من خلال إجراء يُعرف باسم توسيع برنامج واحد وبيانات متعددة (SPMD). يمكنك الاطّلاع على مزيد من التفاصيل في دليل واجهة برمجة التطبيقات لتوزيع Keras 3 الجديد.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

يحدِّد LayoutMap من واجهة برمجة التطبيقات للتوزيع كيفية تقسيم أو نسخ الأوزان والعناصر المصغّرة باستخدام مفاتيح السلاسل، على سبيل المثال، token_embedding/embeddings أدناه، والتي يتم التعامل معها مثل التعبير العادي لمطابقة مسارات عناصر المصغّرة. يتم تقسيم مصفوفات Tensor المطابقة حسب سمات النموذج (8 وحدات معالجة TPU)، وسيتم تكرار العناصر الأخرى بالكامل.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

يتيح لك ModelParallel تقسيم أوزان النماذج أو مصفوفات تنشيط الخلايا على جميع الأجهزة في DeviceMesh. في هذه الحالة، يتم تقسيم بعض أوزان نموذج Gemma 7B على 8 شرائح TPU وفقًا layout_map المحدّد أعلاه. الآن، حمِّل النموذج بالطريقة الموزّعة.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

تأكَّد الآن من أنّه تم تقسيم النموذج بشكل صحيح. لنأخذ decoder_block_1 كمثال.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

الاستنتاج قبل التحسين

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

ينشئ النموذج قائمة بأفلام كوميدية رائعة من التسعينيات لمشاهدتها. سنُجري الآن تعديلات على نموذج Gemma لتغيير أسلوب الإخراج.

الضبط الدقيق باستخدام IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

يمكنك إجراء التعديلات الدقيقة باستخدام Low Rank Adaptation (LoRA). ‫LoRA هي تقنية تحسين تقلّل بشكل كبير من عدد المَعلمات القابلة للتدريب للمهام اللاحقة من خلال تجميد الأوزان الكاملة للنموذج وإدخال عدد أقل من الأوزان الجديدة القابلة للتدريب في النموذج. في الأساس، تعيد LoRA تحديد مَعلمات مصفوفات الأوزان الكاملة الأكبر حجمًا باستخدام مصفوفتَين أصغر حجمًا منخفضة الترتيب AxB للتدريب، وتُعدّ هذه التقنية أسرع بكثير في التدريب وأكثر فعالية في استخدام الذاكرة.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

يُرجى العِلم أنّ تفعيل LoRA يقلل عدد المَعلمات القابلة للتدريب بشكل كبير، من 7 مليارات إلى 11 مليونًا فقط.

الاستنتاج بعد الضبط

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

بعد إجراء التعديلات اللازمة، تعلّم النموذج أسلوب مراجعات الأفلام، وهو ينشئ الآن مراجعات بهذا الأسلوب في سياق الأفلام الكوميدية من التسعينيات.

الخطوات التالية

في هذا الدليل التعليمي، تعرّفت على كيفية استخدام KerasNLP JAX backend لتحسين نموذج Gemma على مجموعة بيانات IMDb بطريقة موزّعة على وحدات معالجة TPU القوية. في ما يلي بعض الاقتراحات حول المواضيع الأخرى التي يمكنك الاطّلاع عليها: