تنظیم توزیع شده با Gemma با استفاده از Keras

مشاهده در ai.google.dev در Kaggle بدوید در Vertex AI باز کنید مشاهده منبع در GitHub

بررسی اجمالی

Gemma خانواده ای از مدل های باز سبک وزن و پیشرفته است که بر اساس تحقیقات و فناوری ساخته شده است که برای ایجاد مدل های Google Gemini استفاده می شود. Gemma را می توان بیشتر برای مطابقت با نیازهای خاص تنظیم کرد. اما مدل‌های زبان بزرگ، مانند جما، می‌توانند از نظر اندازه بسیار بزرگ باشند و برخی از آنها ممکن است برای تنظیم دقیق روی شتاب‌دهنده آواز قرار نگیرند. در این مورد دو رویکرد کلی برای تنظیم دقیق آنها وجود دارد:

  1. پارامتر کارآمد تنظیم دقیق (PEFT)، که به دنبال کاهش اندازه موثر مدل با قربانی کردن برخی وفاداری است. LoRA در این دسته قرار می گیرد و آموزش تنظیم دقیق مدل های Gemma در Keras با استفاده از LoRA نحوه تنظیم دقیق مدل Gemma 2B gemma_2b_en با LoRA با استفاده از KerasNLP روی یک GPU واحد را نشان می دهد.
  2. تنظیم دقیق پارامتر کامل با موازی سازی مدل. موازی سازی مدل، وزن های یک مدل را در چندین دستگاه توزیع می کند و مقیاس افقی را امکان پذیر می کند. می‌توانید در این راهنمای Keras درباره آموزش توزیع‌شده اطلاعات بیشتری کسب کنید.

این آموزش شما را با استفاده از Keras با یک بک‌اند JAX راهنمایی می‌کند تا مدل Gemma 7B را با LoRA و آموزش توزیع‌شده مدل موازی در واحد پردازش تنسور (TPU) Google تنظیم کنید. توجه داشته باشید که LoRA را می توان در این آموزش برای تنظیم کامل پارامتر آهسته تر اما دقیق تر خاموش کرد.

استفاده از شتاب دهنده ها

از نظر فنی می توانید از TPU یا GPU برای این آموزش استفاده کنید.

نکاتی در مورد محیط های TPU

Google 3 محصول دارد که TPU ارائه می کنند:

  • Colab TPU v2 را ارائه می دهد که برای این آموزش کافی نیست.
  • Kaggle TPU v3 را به صورت رایگان ارائه می دهد و آنها برای این آموزش کار می کنند.
  • Cloud TPU TPU v3 و نسل های جدیدتر را ارائه می دهد. یکی از راه های تنظیم آن این است:
    1. یک TPU VM جدید ایجاد کنید
    2. انتقال پورت SSH را برای پورت سرور Jupyter مورد نظر خود تنظیم کنید
    3. Jupyter را نصب کنید و آن را در TPU VM راه اندازی کنید، سپس از طریق "Connect to a local runtime" به Colab متصل شوید.

نکاتی در مورد راه اندازی چند GPU

اگرچه این آموزش بر روی مورد استفاده از TPU تمرکز دارد، اما اگر یک ماشین چند GPU دارید، به راحتی می توانید آن را برای نیازهای خود تطبیق دهید.

اگر ترجیح می‌دهید از طریق Colab کار کنید، می‌توانید مستقیماً از طریق «اتصال به یک ماشین مجازی GCE سفارشی» در منوی Colab Connect، یک VM چند GPU برای Colab تهیه کنید.

ما در اینجا بر استفاده از TPU رایگان Kaggle تمرکز خواهیم کرد.

قبل از اینکه شروع کنی

اعتبارنامه Kaggle

مدل های Gemma توسط Kaggle میزبانی می شود. برای استفاده از Gemma، درخواست دسترسی در Kaggle کنید:

  • وارد شوید یا در kaggle.com ثبت نام کنید
  • کارت مدل Gemma را باز کنید و "Request Access" را انتخاب کنید.
  • فرم رضایت نامه را تکمیل کنید و شرایط و ضوابط را بپذیرید

سپس، برای استفاده از Kaggle API، یک توکن API ایجاد کنید:

  • تنظیمات Kaggle را باز کنید
  • "ایجاد توکن جدید" را انتخاب کنید
  • یک فایل kaggle.json دانلود می شود. این شامل اعتبارنامه Kaggle شما است

سلول زیر را اجرا کنید و زمانی که از شما خواسته شد، اعتبار Kaggle خود را وارد کنید.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

اگر ()kagglehub.login برای شما کار نمی کند، یک راه جایگزین این است که KAGGLE_USERNAME و KAGGLE_KEY را در محیط خود تنظیم کنید.

نصب و راه اندازی

Keras و KerasNLP را با مدل Gemma نصب کنید.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX backend را راه اندازی کنید

JAX را وارد کنید و یک بررسی سلامت روی TPU انجام دهید. Kaggle دستگاه های TPUv3-8 را ارائه می دهد که دارای 8 هسته TPU با 16 گیگابایت حافظه هستند.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

مدل بارگذاری

import keras
import keras_nlp

نکاتی در مورد آموزش با دقت ترکیبی در پردازنده‌های گرافیکی NVIDIA

هنگام آموزش روی پردازنده‌های گرافیکی NVIDIA، می‌توان از دقت ترکیبی ( keras.mixed_precision.set_global_policy('mixed_bfloat16') ) برای افزایش سرعت آموزش با حداقل تأثیر بر کیفیت آموزش استفاده کرد. در بیشتر موارد، توصیه می شود که دقت ترکیبی را روشن کنید زیرا باعث صرفه جویی در حافظه و زمان می شود. با این حال، توجه داشته باشید که در اندازه‌های دسته‌ای کوچک، می‌تواند میزان استفاده از حافظه را 1.5 برابر افزایش دهد (وزن‌ها دو بار بارگذاری می‌شوند، با دقت نصف و با دقت کامل).

برای استنباط، نیم دقت ( keras.config.set_floatx("bfloat16") ) کار می کند و حافظه را ذخیره می کند در حالی که دقت ترکیبی قابل استفاده نیست.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

برای بارگذاری مدل با وزن ها و تانسورهای توزیع شده در TPU ها، ابتدا یک DeviceMesh جدید ایجاد کنید. DeviceMesh مجموعه ای از دستگاه های سخت افزاری را نشان می دهد که برای محاسبات توزیع شده پیکربندی شده اند و در Keras 3 به عنوان بخشی از API توزیع یکپارچه معرفی شده اند.

API توزیع، موازی سازی داده ها و مدل ها را امکان پذیر می کند و امکان مقیاس بندی کارآمد مدل های یادگیری عمیق را در چندین شتاب دهنده و میزبان فراهم می کند. از چارچوب زیربنایی (مانند JAX) برای توزیع برنامه و تانسورها بر اساس دستورالعمل‌های اشتراک‌گذاری از طریق رویه‌ای به نام تک‌برنامه، توسعه داده‌های چندگانه (SPMD) استفاده می‌کند. جزئیات بیشتر را در راهنمای جدید API توزیع Keras 3 بررسی کنید.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap از API توزیع، نحوه تقسیم یا تکثیر وزن‌ها و تانسورها را با استفاده از کلیدهای رشته‌ای، به‌عنوان مثال، token_embedding/embeddings در زیر، مشخص می‌کند که برای مطابقت با مسیرهای تانسور مانند regex رفتار می‌شود. تانسورهای منطبق با ابعاد مدل (8 TPU) خرد شده اند. بقیه به طور کامل تکرار خواهند شد.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel به شما امکان می دهد وزن های مدل یا تانسورهای فعال سازی را در همه دستگاه های موجود در DeviceMesh خرد کنید. در این مورد، برخی از وزن های مدل Gemma 7B بر اساس layout_map که در بالا تعریف شده است، در بین 8 تراشه TPU تقسیم می شوند. اکنون مدل را به روش توزیع شده بارگذاری کنید.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

اکنون بررسی کنید که مدل به درستی پارتیشن بندی شده است. بیایید decoder_block_1 به عنوان مثال در نظر بگیریم.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

استنتاج قبل از تنظیم دقیق

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

این مدل لیستی از فیلم های کمدی عالی دهه 90 را برای تماشا ایجاد می کند. اکنون مدل Gemma را برای تغییر سبک خروجی تنظیم می کنیم.

با IMDB هماهنگ شوید

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

تنظیم دقیق را با استفاده از تطبیق رتبه پایین (LoRA) انجام دهید. LoRA یک تکنیک تنظیم دقیق است که تعداد پارامترهای قابل آموزش برای کارهای پایین دستی را با انجماد وزنه های کامل مدل و قرار دادن تعداد کمتری وزنه های قابل آموزش جدید در مدل به میزان زیادی کاهش می دهد. اساساً LoRA ماتریس‌های بزرگ‌تر با وزن کامل را با 2 ماتریس کوچک‌تر AxB با رتبه پایین‌تر برای تمرین مجددا پارامترسازی می‌کند و این تکنیک باعث می‌شود تمرین بسیار سریع‌تر و حافظه کارآمدتر شود.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

توجه داشته باشید که فعال کردن LoRA تعداد پارامترهای قابل آموزش را به میزان قابل توجهی کاهش می دهد، از 7 میلیارد به تنها 11 میلیون.

استنباط پس از تنظیم دقیق

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

پس از تنظیم دقیق، این مدل سبک نقد فیلم را آموخته و اکنون در حال تولید خروجی در آن سبک در زمینه فیلم‌های کمدی دهه 90 است.

بعدش چی

در این آموزش، نحوه استفاده از KerasNLP JAX Backend را برای تنظیم دقیق مدل Gemma در مجموعه داده های IMDb به صورت توزیع شده در TPU های قدرتمند یاد گرفتید. در اینجا چند پیشنهاد برای چیزهای دیگری برای یادگیری وجود دارد: