ضبط موزّع بالتعاون مع "جيما" باستخدام Keras

العرض على ai.google.dev التنفيذ في Kaggle الفتح في Vertex AI الاطّلاع على المصدر على GitHub

نظرة عامة

"جيما" هي عائلة من النماذج المفتوحة والخفيفة والمتطورة التي تم إنشاؤها من الأبحاث والتكنولوجيا المستخدَمة لإنشاء نماذج Google Gemini. ويمكن تحسين جيما لتناسب احتياجات محددة. ولكن النماذج اللغوية الكبيرة، مثل جيما، يمكن أن تكون كبيرة جدًا في الحجم وقد لا يتناسب بعضها مع مسرِّعة صوت للضبط الدقيق. في هذه الحالة، هناك نهجان عامان لضبطهما:

  1. التوليف الدقيق للمعلمة (PEFT)، الذي يسعى إلى تقليل حجم النموذج الفعال من خلال التضحية ببعض الدقة. تقع لورا ضمن هذه الفئة، ويوضح البرنامج التعليمي Fine-tune tune gemma في Keras باستخدام LoRA كيفية تحسين نموذج Gemma 2B gemma_2b_en باستخدام LoRA باستخدام KerasNLP على وحدة معالجة رسومات واحدة.
  2. ضبط المعلمة الكاملة باستخدام توازي النموذج. توزِّع ميزة التوازي في النموذج أوزان نموذج واحد على أجهزة متعددة وتتيح التحجيم الأفقي. يمكنك الاطّلاع على مزيد من المعلومات حول التدريب الموزَّع في دليل Keras هذا.

يرشدك هذا البرنامج التعليمي إلى كيفية استخدام Keras مع خلفية JAX لتحسين نموذج Gemma 7B مع LoRA والتدريب الموزع للنماذج على وحدة معالجة Tensor (TPU) من Google. تجدر الإشارة إلى أنّه يمكن إيقاف LoRA في هذا الدليل التوجيهي لضبط المَعلمة الكاملة بشكل أبطأ ولكن بشكل أكثر دقة.

استخدام مسرِّعات الأعمال

من الناحية الفنية، يمكنك استخدام وحدة معالجة الموتّرات أو وحدة معالجة الرسومات في هذا البرنامج التعليمي.

ملاحظات حول بيئات TPU

تقدّم Google 3 منتجات تقدّم وحدات معالجة الموتّرات:

  • يوفّر تطبيق Colab الإصدار 2 من TPU، وهو غير كافٍ لهذا البرنامج التعليمي.
  • توفّر Kaggle الإصدار 3 من TPU مجانًا ويمكن استخدامها في هذا البرنامج التعليمي.
  • توفّر Cloud TPU الإصدار 3 من TPU والأجيال الأحدث. لإعداد هذا الإجراء، يمكنك اتّباع الخطوات التالية:
    1. إنشاء جهاز افتراضي (VM) جديد
    2. إعداد إعادة توجيه منفذ SSH لمنفذ خادم Jupyter المقصود
    3. ثبِّت Jupyter وابدأ تشغيله على جهاز TPU VM، ثم اربطه بـ Colab من خلال خيار "الاتصال بوقت تشغيل محلي".

ملاحظات حول إعداد وحدات معالجة رسومات متعددة

بالرغم من أنّ هذا الدليل التوجيهي يركّز على حالة استخدام وحدة معالجة الموتّرات، يمكنك تعديلها بسهولة تناسب احتياجاتك إذا كان لديك جهاز يتضمّن وحدة معالجة رسومات متعدّدة.

إذا كنت تفضّل العمل باستخدام Colab، يمكنك أيضًا توفير جهاز افتراضي (VM) يتضمّن وحدات معالجة رسومات متعددة من أجل Colab مباشرةً من خلال "الاتصال بجهاز GCE VM مخصّص" في قائمة Colab Connect.

سنركّز على استخدام وحدة معالجة الموتّرات المجانية من Kaggle هنا.

قبل البدء

بيانات اعتماد Kaggle

تستضيف Kaggle نماذج Gemma. لاستخدام Gemma، اطلب الوصول على Kaggle:

  • تسجيل الدخول أو التسجيل على kaggle.com
  • افتح بطاقة نموذج Gemma واختَر "طلب الوصول".
  • املأ نموذج الموافقة واقبل الأحكام والشروط

بعد ذلك، لاستخدام واجهة برمجة تطبيقات Kaggle، أنشئ رمزًا مميزًا لواجهة برمجة التطبيقات:

  • افتح إعدادات Kaggle.
  • انقر على إنشاء رمز مميّز جديد.
  • تم تنزيل ملف kaggle.json. تحتوي على بيانات اعتماد Kaggle الخاصة بك

قم بتشغيل الخلية التالية وأدخل بيانات اعتماد Kaggle عندما يُطلب منك ذلك.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

هناك طريقة بديلة، وهي ضبط KAGGLE_USERNAME وKAGGLE_KEY في بيئتك إذا لم يعمل kagglehub.login() .

تثبيت

تثبيت Keras و KerasNLP باستخدام نموذج Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

إعداد الواجهة الخلفية لـ Keras JAX

استورِد JAX وأجرِ فحصًا للسلامة على TPU. تقدم Kaggle أجهزة TPUv3-8 التي تحتوي على 8 أنوية TPU مع ذاكرة 16 غيغابايت لكل منها.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

تحميل النموذج

import keras
import keras_nlp

ملاحظات حول تطبيق دقة مختلطة على وحدات معالجة الرسومات NVIDIA

عند التدريب على وحدات معالجة الرسومات NVIDIA، يمكن استخدام الدقة المختلطة (keras.mixed_precision.set_global_policy('mixed_bfloat16')) لتسريع عملية التدريب بأقل تأثير ممكن على جودة التدريب. ويُنصح في معظم الحالات بتفعيل الدقة المختلطة لأنها توفر كلاً من الذاكرة والوقت. ومع ذلك، انتبه إلى أنه في أحجام الدُفعات الصغيرة، يمكن أن تزيد من استخدام الذاكرة بمقدار 1.5 مرة (سيتم تحميل القيم مرتين، بنصف دقة وبدقة كاملة).

للاستنتاج، سيعمل نصف الدقة (keras.config.set_floatx("bfloat16")) على توفير الذاكرة في حين لا يمكن تطبيق الدقة المختلطة.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

لتحميل النموذج بالأوزان والأوزان الموزَّعة على وحدات معالجة الموتّرات، عليك أولاً إنشاء DeviceMesh جديد. يمثّل DeviceMesh مجموعة من الأجهزة التي تم إعدادها لإجراء العمليات الحسابية الموزعة وتم تقديمها في Keras 3 كجزء من واجهة برمجة التطبيقات الموحّدة للتوزيع.

تُتيح واجهة برمجة التطبيقات للتوزيع إمكانية توازي البيانات والنماذج، ما يسمح بتوسيع نطاق نماذج التعليم المعمّق بفعالية من خلال مسرِّعات أعمال ومضيفين متعددين. تستفيد من الإطار الأساسي (على سبيل المثال، JAX) لتوزيع البرنامج والموترات وفقًا للتوجيهات المقسمة من خلال إجراء يسمى برنامج واحد، توسع البيانات المتعددة (SPMD). يمكنك الاطّلاع على مزيد من التفاصيل في دليل واجهة برمجة التطبيقات Keras 3 للتوزيع الجديد.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

يحدّد LayoutMap من واجهة برمجة التطبيقات للتوزيع كيفية تجزئة الأوزان والأضداد أو نسخها باستخدام مفاتيح السلسلة، على سبيل المثال، token_embedding/embeddings أدناه، والتي يتم التعامل معها على غرار التعبير العادي لمطابقة مسارات الموتّر. تتم تجزئة الموجات المطابقة بأبعاد النموذج (8 وحدات معالجة مركزية)؛ وسيتم نسخ العناصر الأخرى بالكامل.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

تسمح لك ModelParallel بتقسيم قيم ترجيح النماذج أو أشدها التفعيل على جميع الأجهزة على DeviceMesh. في هذه الحالة، تتمّ تجزئة بعض أوزان نموذج Gemma 7B على 8 شرائح بولي يورثان متلدّن بالحرارة وفقًا لمعيار layout_map المحدّد أعلاه. الآن قم بتحميل النموذج بالطريقة الموزعة.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

تحقَّق الآن من تقسيم النموذج بشكل صحيح. لنلقِ نظرة على decoder_block_1 كمثال.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

الاستنتاج قبل الضبط

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

ينشئ النموذج قائمة بالأفلام الكوميدية الرائعة من التسعينيات. ونعمل الآن على تحسين نموذج جيما لتغيير نمط الإخراج.

الضبط الدقيق باستخدام IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

نفِّذ الضبط باستخدام تعديل الترتيب المنخفض (LoRA). يعتبر LoRA تقنية ضبط تقلل بشكل كبير من عدد المعلمات القابلة للتدريب للمهام النهائية عن طريق تجميد الأوزان الكاملة للنموذج وإدراج عدد أصغر من الترجيح الجديدة القابلة للتدريب في النموذج. في الأساس، تعيد لورا ضبط المصفوفات الأكبر حجمًا ذات الوزن الكامل بواسطة مصفوفتين أصغر حجمًا ومنخفضة الترتيب AxB لتدريبها وهذه التقنية تجعل التدريب أسرع وأكثر كفاءة في الذاكرة.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

تجدر الإشارة إلى أنّ تفعيل LoRA يقلّل بشكل كبير من عدد المعلَمات القابلة للتدريب، من 7 مليار إلى 11 مليون معلَمات فقط.

الاستنتاج بعد الضبط

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

بعد إجراء التعديلات، تعلّم النموذج نمط مراجعات الأفلام وهي الآن تُنتج المخرجات بهذا الأسلوب في سياق الأفلام الكوميدية في التسعينيات.

الخطوات التالية

تعلّمت في هذا البرنامج التعليمي كيفية استخدام خلفية KerasNLP JAX لضبط نموذج جيما على مجموعة بيانات IMDb بشكل موزَّع على وحدات معالجة الموتّرات القوية. فيما يلي بعض الاقتراحات لما يجب تعلمه أيضًا: