مشاهده در ai.google.dev | در Google Colab اجرا شود | در Kaggle بدوید | در Vertex AI باز کنید | مشاهده منبع در GitHub |
نمای کلی
Gemma خانواده ای از مدل های باز سبک وزن و پیشرفته است که بر اساس تحقیقات و فناوری ساخته شده است که برای ایجاد مدل های Google Gemini استفاده می شود. Gemma را می توان بیشتر برای مطابقت با نیازهای خاص تنظیم کرد. اما مدلهای زبان بزرگ، مانند جما، میتوانند از نظر اندازه بسیار بزرگ باشند و برخی از آنها ممکن است برای تنظیم دقیق روی شتابدهنده آواز قرار نگیرند. در این مورد دو رویکرد کلی برای تنظیم دقیق آنها وجود دارد:
- پارامتر کارآمد تنظیم دقیق (PEFT)، که به دنبال کاهش اندازه موثر مدل با قربانی کردن برخی وفاداری است. LoRA در این دسته قرار می گیرد و آموزش تنظیم دقیق مدل های Gemma در Keras با استفاده از LoRA نحوه تنظیم دقیق مدل Gemma 2B
gemma_2b_en
با LoRA با استفاده از KerasNLP روی یک GPU واحد را نشان می دهد. - تنظیم دقیق پارامتر کامل با موازی سازی مدل. موازی سازی مدل، وزن های یک مدل را در چندین دستگاه توزیع می کند و مقیاس افقی را امکان پذیر می کند. میتوانید در این راهنمای Keras درباره آموزش توزیعشده اطلاعات بیشتری کسب کنید.
این آموزش شما را با استفاده از Keras با یک بکاند JAX راهنمایی میکند تا مدل Gemma 7B را با LoRA و آموزش توزیعشده مدل موازی در واحد پردازش تنسور (TPU) Google تنظیم کنید. توجه داشته باشید که LoRA را می توان در این آموزش برای تنظیم کامل پارامتر آهسته تر اما دقیق تر خاموش کرد.
استفاده از شتاب دهنده ها
از نظر فنی می توانید از TPU یا GPU برای این آموزش استفاده کنید.
نکاتی در مورد محیط های TPU
Google 3 محصول دارد که TPU ارائه می کنند:
- کولب TPU v2 را به صورت رایگان ارائه می کند که برای این آموزش کافی است.
- Kaggle TPU v3 را به صورت رایگان ارائه می دهد و آنها نیز برای این آموزش کار می کنند.
- Cloud TPU TPU v3 و نسل های جدیدتر را ارائه می دهد. یکی از راه های تنظیم آن این است:
- یک TPU VM جدید ایجاد کنید
- انتقال پورت SSH را برای پورت سرور Jupyter مورد نظر خود تنظیم کنید
- Jupyter را نصب کنید و آن را در TPU VM راه اندازی کنید، سپس از طریق "Connect to a local runtime" به Colab متصل شوید.
نکاتی در مورد راه اندازی چند GPU
اگرچه این آموزش بر روی مورد استفاده از TPU تمرکز دارد، اما اگر یک ماشین چند GPU دارید، به راحتی می توانید آن را برای نیازهای خود تطبیق دهید.
اگر ترجیح میدهید از طریق Colab کار کنید، میتوانید مستقیماً از طریق «اتصال به یک ماشین مجازی GCE سفارشی» در منوی Colab Connect، یک VM چند GPU برای Colab تهیه کنید.
ما در اینجا بر استفاده از TPU رایگان Kaggle تمرکز خواهیم کرد.
قبل از شروع
اعتبارنامه Kaggle
مدل های Gemma توسط Kaggle میزبانی می شود. برای استفاده از Gemma، درخواست دسترسی در Kaggle کنید:
- وارد شوید یا در kaggle.com ثبت نام کنید
- کارت مدل Gemma را باز کنید و "Request Access" را انتخاب کنید.
- فرم رضایت نامه را تکمیل کنید و شرایط و ضوابط را بپذیرید
سپس، برای استفاده از Kaggle API، یک توکن API ایجاد کنید:
- تنظیمات Kaggle را باز کنید
- "ایجاد توکن جدید" را انتخاب کنید
- یک فایل
kaggle.json
دانلود می شود. این شامل اعتبارنامه Kaggle شما است
سلول زیر را اجرا کنید و زمانی که از شما خواسته شد، اعتبار Kaggle خود را وارد کنید.
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
اگر ()kagglehub.login برای شما کار نمی کند، یک راه جایگزین این است که KAGGLE_USERNAME و KAGGLE_KEY را در محیط خود تنظیم کنید.
نصب و راه اندازی
Keras و KerasNLP را با مدل Gemma نصب کنید.
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
باطن Keras JAX را راه اندازی کنید
JAX را وارد کنید و یک بررسی سلامت روی TPU انجام دهید. Kaggle دستگاه های TPUv3-8 را ارائه می دهد که دارای 8 هسته TPU با 16 گیگابایت حافظه هستند.
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
مدل بارگذاری
import keras
import keras_nlp
نکاتی در مورد آموزش با دقت ترکیبی در پردازندههای گرافیکی NVIDIA
هنگام آموزش روی پردازندههای گرافیکی NVIDIA، میتوان از دقت ترکیبی ( keras.mixed_precision.set_global_policy('mixed_bfloat16')
) برای افزایش سرعت آموزش با حداقل تأثیر بر کیفیت آموزش استفاده کرد. در بیشتر موارد، توصیه می شود که دقت ترکیبی را روشن کنید زیرا باعث صرفه جویی در حافظه و زمان می شود. با این حال، توجه داشته باشید که در اندازههای دستهای کوچک، میتواند میزان استفاده از حافظه را 1.5 برابر افزایش دهد (وزنها دو بار بارگذاری میشوند، با دقت نصف و با دقت کامل).
برای استنباط، نیم دقت ( keras.config.set_floatx("bfloat16")
) کار می کند و حافظه را ذخیره می کند در حالی که دقت ترکیبی قابل استفاده نیست.
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
برای بارگذاری مدل با وزن ها و تانسورهای توزیع شده در TPU ها، ابتدا یک DeviceMesh
جدید ایجاد کنید. DeviceMesh
مجموعه ای از دستگاه های سخت افزاری را نشان می دهد که برای محاسبات توزیع شده پیکربندی شده اند و در Keras 3 به عنوان بخشی از API توزیع یکپارچه معرفی شده اند.
API توزیع، موازی سازی داده ها و مدل ها را امکان پذیر می کند و امکان مقیاس بندی کارآمد مدل های یادگیری عمیق را در چندین شتاب دهنده و میزبان فراهم می کند. از چارچوب زیربنایی (مانند JAX) برای توزیع برنامه و تانسورها بر اساس دستورالعملهای اشتراکگذاری از طریق رویهای به نام تکبرنامه، توسعه دادههای چندگانه (SPMD) استفاده میکند. جزئیات بیشتر را در راهنمای جدید API توزیع Keras 3 بررسی کنید.
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
LayoutMap
از API توزیع، نحوه تقسیم یا تکثیر وزنها و تانسورها را با استفاده از کلیدهای رشتهای، بهعنوان مثال، token_embedding/embeddings
در زیر، مشخص میکند که برای مطابقت با مسیرهای تانسور مانند regex رفتار میشود. تانسورهای منطبق با ابعاد مدل (8 TPU) خرد شده اند. بقیه به طور کامل تکرار خواهند شد.
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
ModelParallel
به شما امکان می دهد وزن های مدل یا تانسورهای فعال سازی را در همه دستگاه های موجود در DeviceMesh
خرد کنید. در این مورد، برخی از وزن های مدل Gemma 7B بر اساس layout_map
که در بالا تعریف شده است، در بین 8 تراشه TPU تقسیم می شوند. اکنون مدل را به روش توزیع شده بارگذاری کنید.
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
اکنون بررسی کنید که مدل به درستی پارتیشن بندی شده است. بیایید decoder_block_1
به عنوان مثال در نظر بگیریم.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
استنتاج قبل از تنظیم دقیق
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
این مدل لیستی از فیلم های کمدی عالی دهه 90 را برای تماشا ایجاد می کند. اکنون مدل Gemma را برای تغییر سبک خروجی تنظیم می کنیم.
با IMDB هماهنگ شوید
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
تنظیم دقیق را با استفاده از تطبیق رتبه پایین (LoRA) انجام دهید. LoRA یک تکنیک تنظیم دقیق است که تعداد پارامترهای قابل آموزش برای کارهای پایین دستی را با انجماد وزنه های کامل مدل و قرار دادن تعداد کمتری وزنه های قابل آموزش جدید در مدل به میزان زیادی کاهش می دهد. اساساً LoRA ماتریسهای بزرگتر با وزن کامل را با 2 ماتریس کوچکتر AxB با رتبه پایینتر برای تمرین مجددا پارامترسازی میکند و این تکنیک باعث میشود تمرین بسیار سریعتر و حافظه کارآمدتر شود.
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
توجه داشته باشید که فعال کردن LoRA تعداد پارامترهای قابل آموزش را به میزان قابل توجهی کاهش می دهد، از 7 میلیارد به تنها 11 میلیون.
استنباط پس از تنظیم دقیق
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
پس از تنظیم دقیق، این مدل سبک نقد فیلم را آموخته و اکنون در حال تولید خروجی در آن سبک در زمینه فیلمهای کمدی دهه 90 است.
بعدش چی
در این آموزش، نحوه استفاده از KerasNLP JAX Backend را برای تنظیم دقیق مدل Gemma در مجموعه داده های IMDb به صورت توزیع شده در TPU های قدرتمند یاد گرفتید. در اینجا چند پیشنهاد برای چیزهای دیگری برای یادگیری وجود دارد:
- نحوه شروع کار با Keras Gemma را بیاموزید.
- نحوه تنظیم دقیق مدل Gemma در GPU را بیاموزید.