ضبط موزّع بالتعاون مع "جيما" باستخدام Keras

العرض على ai.google.dev التنفيذ في Google Colab التنفيذ في Kaggle الفتح في Vertex AI عرض المصدر على GitHub

نظرة عامة

"جيما" هي مجموعة من النماذج المتطوّرة والخفيفة المتاحة للجميع، والتي تم إنشاؤها استنادًا إلى الأبحاث والتكنولوجيا المستخدَمة لإنشاء نماذج Google Gemini. يمكن إجراء تحسينات إضافية على Gemma لتناسب الاحتياجات الخاصة. ولكن النماذج اللغوية الكبيرة، مثل Gemma، قد تكون كبيرة جدًا في الحجم وقد لا يتناسب بعضها مع مسرِّع الغناء من أجل الضبط الدقيق. في هذه الحالة، هناك نهجان عامان لضبطهما:

  1. الضبط الدقيق للمعلمة (PEFT) الذي يسعى إلى تقليص حجم النموذج الفعال من خلال التضحية ببعض الدقة. تندرج لورا ضمن هذه الفئة، ويشرح النموذج الدقيق لنماذج Gemma في Keras باستخدام LoRA كيفية تحسين نموذج Gemma 2B gemma_2b_en من خلال LoRA باستخدام KerasNLP في وحدة معالجة رسومات واحدة.
  2. الضبط الكامل للمَعلمة باستخدام موازاة النموذج توزع موازاة النموذج أوزان نموذج واحد على أجهزة متعددة وتتيح القياس الأفقي. يمكنك معرفة المزيد من المعلومات عن التدريب الموزَّع في دليل "كيرا" هذا.

يرشدك هذا البرنامج التعليمي إلى كيفية استخدام Keras مع خلفية JAX لتحسين نموذج Gemma 7B باستخدام LoRA والتدريب الموزَّع للنماذج التكافؤية ضمن وحدة معالجة Tensor (TPU) من Google. لاحظ أنه يمكن إيقاف LoRA في هذا البرنامج التعليمي لضبط المعلمة الكاملة بشكل أبطأ وأكثر دقة.

استخدام مسرِّعات الأعمال

من الناحية الفنية، يمكنك استخدام وحدة معالجة الموتّرات أو وحدة معالجة الرسومات في هذا البرنامج التعليمي.

ملاحظات حول بيئات وحدة معالجة الموتّرات

تقدّم Google 3 منتجات توفّر وحدات معالجة الموتّرات:

  • يوفّر Colab الإصدار الثاني من TPU مجانًا، وهو ما يكفي لهذا الدليل التوجيهي.
  • يوفر Kaggle الإصدار 3 من TPU مجانًا، كما يمكنه أيضًا استخدام هذا البرنامج التعليمي.
  • تتوفّر في Cloud TPU وحدة معالجة الموتّرات (TPU) الإصدار 3 والأجيال الأحدث. إحدى طرق إعداده هي:
    1. إنشاء جهاز افتراضي TPU جديد
    2. إعداد إعادة توجيه منفذ SSH لمنفذ خادم Jupyter المقصود
    3. ثبِّت تطبيق Jupyter وشغِّله على جهاز TPU الافتراضي، ثم اتصِل بوحدة Colab من خلال "الاتصال ببيئة تشغيل على الجهاز".

ملاحظات حول إعداد وحدات معالجة رسومات متعددة

على الرغم من أنّ هذا البرنامج التعليمي يركّز على حالة استخدام وحدة معالجة الموتّرات، يمكنك تعديلها بسهولة لتلبية احتياجاتك الخاصة إذا كنت تستخدم جهازًا يتضمّن وحدة معالجة رسومات متعدّدة.

إذا كنت تفضِّل العمل من خلال Colab، من الممكن أيضًا توفير جهاز افتراضي متعدد معالجة الرسومات في Colab مباشرةً من خلال خيار "الاتصال بجهاز GCE افتراضي مخصّص". في قائمة Colab Connect.

سنركّز على استخدام وحدة معالجة الموتّرات المجانية من Kaggle هنا.

قبل البدء

بيانات اعتماد Kaggle

تتم استضافة نماذج Gemma بواسطة Kaggle. لاستخدام Gemma، اطلب الوصول على Kaggle:

  • تسجيل الدخول أو التسجيل على kaggle.com
  • افتح بطاقة نموذج Gemma واختَر "طلب الوصول".
  • إكمال نموذج الموافقة وقبول الأحكام والشروط

بعد ذلك، لاستخدام واجهة برمجة تطبيقات Kaggle، قم بإنشاء رمز مميز لواجهة برمجة التطبيقات:

  • افتح إعدادات Kaggle
  • اختيار "إنشاء رمز مميّز جديد"
  • تم تنزيل ملف kaggle.json. يحتوي على بيانات اعتماد Kaggle

قم بتشغيل الخلية التالية وأدخل بيانات اعتماد Kaggle عندما يُطلب منك ذلك.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

طريقة بديلة هي تعيين KAGGLE_USERNAME وKAGGLE_KEY في بيئتك إذا كان kagglehub.login() لا يناسبك.

تثبيت

تثبيت Keras وKerasNLP مع نموذج Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

إعداد خلفية Keras JAX

يمكنك استيراد JAX وإجراء فحص من سلامة الجهاز على وحدة معالجة الموتّرات. تقدم Kaggle أجهزة TPUv3-8 التي تحتوي على 8 نوى TPU مع ذاكرة تبلغ 16 غيغابايت لكل منها.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

تحميل النموذج

import keras
import keras_nlp

ملاحظات حول التدريب على الدقة المختلطة في وحدات معالجة الرسومات NVIDIA

عند التدريب على وحدات معالجة الرسومات NVIDIA، يمكن استخدام الدقة المختلطة (keras.mixed_precision.set_global_policy('mixed_bfloat16')) لتسريع التدريب بأقل تأثير ممكن في جودة التدريب. في معظم الحالات، يُنصح بتفعيل نظام الدقة المختلط لأنّه يوفّر كلّاً من الذاكرة والوقت. ومع ذلك، يجب الانتباه إلى أنّه يمكن تضخيم استخدام الذاكرة بمقدار 1.5 مرة في أحجام الدُفعات الصغيرة (سيتم تحميل القيم التقديرية مرتين بنصف الدقة والدقة الكاملة).

للاستنتاج، ستعمل دقة نصف الدقة (keras.config.set_floatx("bfloat16")) وتوفّر مساحة الذاكرة بينما لا تكون الدقة المختلطة غير سارية.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

لتحميل النموذج الذي يحتوي على قيم الأوزان وأدوات قياس الأوزان الموزعة على وحدات معالجة الموتّرات، يجب أولاً إنشاء سمة DeviceMesh جديدة. يمثّل DeviceMesh مجموعة من الأجهزة التي تم إعدادها للحوسبة الموزعة وتم طرحها في Keras 3 كجزء من واجهة برمجة التطبيقات الموحدة للتوزيع.

تتيح واجهة برمجة التطبيقات للتوزيع إمكانية موازاة البيانات والنماذج، ما يسمح بتوسيع نطاق نماذج التعلم المتعمّق بفعالية على برامج مسرِّعات ومضيفات متعددة. وتستفيد هذه الخدمة من إطار العمل الأساسي (مثل JAX) لتوزيع البرنامج ومُؤسّسات الرسم وفقًا لتوجيهات التقسيم من خلال إجراء يُسمى برنامج فردي، أو توسيع البيانات المتعددة (SPMD). يمكنك الاطّلاع على مزيد من التفاصيل في دليل واجهة برمجة التطبيقات Keras 3 distribution API الجديد.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

تحدّد الدالة LayoutMap من واجهة برمجة التطبيقات للتوزيع طريقة تجزئة قيم الترجيح ومسؤولات التنس أو تكرارها، باستخدام مفاتيح السلسلة، على سبيل المثال، token_embedding/embeddings أدناه، والتي يتم التعامل معها مثل التعبير العادي لمطابقة مسارات Tenor. يتم تقسيم مُوصِلات الوتيرة المطابقة باستخدام أبعاد النموذج (8 وحدات معالجة موتّرة). فسيتم نسخ بعضها الآخر بالكامل.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

تسمح لك الدالة ModelParallel بتقسيم ترجيحات النموذج أو دوالّ التفعيل في جميع الأجهزة على DeviceMesh. في هذه الحالة، يتم تقسيم بعض القيم التقديرية لطراز Gemma 7B على 8 شرائح TPU وفقًا للسمة layout_map المحدّدة أعلاه. والآن قم بتحميل النموذج بالطريقة الموزعة.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

والآن تحقق من أنه تم تقسيم النموذج بشكل صحيح. لِنأخذ decoder_block_1 كمثال.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

الاستنتاج قبل الضبط

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

ينشئ النموذج قائمة بالأفلام الكوميدية الرائعة من التسعينيات لمشاهدتها. والآن نقوم بضبط نموذج Gemma لتغيير نمط الإخراج.

الضبط الدقيق باستخدام IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

أجرِ عملية ضبط باستخدام ميزة تعديل الترتيب المنخفض (LoRA). إن LoRA هي تقنية ضبط دقيقة تقلل بشكل كبير من عدد المعلمات القابلة للتدريب للمهام التي يتم تنفيذها عن طريق تجميد الأوزان الكاملة للنموذج وإدخال عدد أقل من الأوزان الجديدة القابلة للتدريب في النموذج. في الأساس، تقوم LoRA بإعادة معلَمة المصفوفات الأكبر ذات الوزن الكامل باستخدام مصفوفتين أصغر حجمًا AxB من حيث الترتيب، وهذه التقنية تجعل التدريب أسرع بكثير وأكثر كفاءة في الذاكرة.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

لاحظ أن تفعيل LoRA يقلل عدد المعلمات القابلة للتدريب بشكل كبير، من 7 مليارات إلى 11 مليونًا فقط.

الاستنتاج بعد الضبط

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

بعد تحسين المحتوى، تعلّم صانعة المحتوى أسلوب مراجعات الأفلام، وبدأت بإنشاء فيديوهات بهذا الأسلوب في سياق أفلام كوميدية من التسعينيات.

الخطوات التالية

في هذا البرنامج التعليمي، تعلّمت كيفية استخدام خلفية KerasNLP JAX لتحسين نموذج Gemma في مجموعة بيانات IMDb بطريقة موزَّعة باستخدام وحدات معالجة الموتّرات الفعّالة. إليك بعض الاقتراحات لما يجب تعلمه أيضًا: