কেরাস ব্যবহার করে জেমার সাথে টিউনিং বিতরণ করা হয়েছে

ai.google.dev-এ দেখুন Google Colab-এ চালান কাগলে চালান Vertex AI-তে খুলুন GitHub-এ উৎস দেখুন

ওভারভিউ

Gemma হল হালকা ওজনের, অত্যাধুনিক ওপেন মডেলের একটি পরিবার যা Google Gemini মডেল তৈরি করতে ব্যবহৃত গবেষণা এবং প্রযুক্তি থেকে তৈরি। নির্দিষ্ট প্রয়োজন অনুসারে জেমাকে আরও সুন্দর করা যেতে পারে। কিন্তু বড় ভাষার মডেল, যেমন জেমা, আকারে অনেক বড় হতে পারে এবং তাদের মধ্যে কিছু ফিনটিউনিংয়ের জন্য একটি সিং এক্সিলারেটরে ফিট নাও হতে পারে। এই ক্ষেত্রে তাদের ফাইনটিউন করার জন্য দুটি সাধারণ পদ্ধতি রয়েছে:

  1. পরামিতি দক্ষ ফাইন-টিউনিং (PEFT), যা কিছু বিশ্বস্ততা বলি দিয়ে কার্যকর মডেলের আকার সঙ্কুচিত করতে চায়। LoRA এই বিভাগে পড়ে এবং LoRA টিউটোরিয়াল ব্যবহার করে কেরাসের ফাইন-টিউন জেম্মা মডেলগুলি দেখায় যে কীভাবে একটি একক GPU-তে KerasNLP ব্যবহার করে LoRA-এর সাথে Gemma 2B মডেল gemma_2b_en ফাইনটিউন করা যায়।
  2. মডেলের সমান্তরালতার সাথে সম্পূর্ণ প্যারামিটার ফাইনটিউনিং। মডেল সমান্তরালতা একাধিক ডিভাইস জুড়ে একটি একক মডেলের ওজন বিতরণ করে এবং অনুভূমিক স্কেলিং সক্ষম করে। আপনি এই কেরাস গাইডে বিতরণ করা প্রশিক্ষণ সম্পর্কে আরও জানতে পারেন।

এই টিউটোরিয়ালটি আপনাকে JAX ব্যাকএন্ডের সাথে কেরাস ব্যবহার করে LoRA-এর সাথে Gemma 7B মডেল এবং Google-এর টেনসর প্রসেসিং ইউনিট (TPU)-তে মডেল-সমান্তরিত প্রশিক্ষণ প্রদান করে। মনে রাখবেন যে এই টিউটোরিয়ালে LoRA বন্ধ করা যেতে পারে একটি ধীর কিন্তু আরও সঠিক পূর্ণ-প্যারামিটার টিউনিংয়ের জন্য।

এক্সিলারেটর ব্যবহার করে

প্রযুক্তিগতভাবে আপনি এই টিউটোরিয়ালের জন্য TPU বা GPU ব্যবহার করতে পারেন।

TPU পরিবেশের উপর নোট

Google-এর 3টি পণ্য রয়েছে যা TPU প্রদান করে:

  • Colab বিনামূল্যে TPU v2 প্রদান করে, যা এই টিউটোরিয়ালের জন্য যথেষ্ট।
  • Kaggle বিনামূল্যে TPU v3 অফার করে এবং তারা এই টিউটোরিয়ালের জন্যও কাজ করে।
  • ক্লাউড TPU টিপিইউ v3 এবং নতুন প্রজন্মের অফার করে। এটি সেট আপ করার একটি উপায় হল:
    1. একটি নতুন TPU VM তৈরি করুন
    2. আপনার অভিপ্রেত জুপিটার সার্ভার পোর্টের জন্য SSH পোর্ট ফরওয়ার্ডিং সেট আপ করুন
    3. Jupyter ইনস্টল করুন এবং এটি TPU VM-এ চালু করুন, তারপর "স্থানীয় রানটাইমে কানেক্ট করুন"-এর মাধ্যমে Colab-এর সাথে কানেক্ট করুন

মাল্টি-জিপিইউ সেটআপের নোট

যদিও এই টিউটোরিয়ালটি টিপিইউ ব্যবহারের ক্ষেত্রে ফোকাস করে, আপনার যদি একটি মাল্টি-জিপিইউ মেশিন থাকে তবে আপনি সহজেই এটিকে আপনার নিজের প্রয়োজনে মানিয়ে নিতে পারেন।

আপনি যদি Colab-এর মাধ্যমে কাজ করতে পছন্দ করেন, তাহলে Colab Connect মেনুতে "কাস্টম GCE VM-এর সাথে কানেক্ট করুন"-এর মাধ্যমে Colab-এর জন্য একটি মাল্টি-GPU VM-এর ব্যবস্থা করাও সম্ভব।

আমরা এখানে Kaggle থেকে বিনামূল্যে TPU ব্যবহার করার উপর ফোকাস করব।

আপনি শুরু করার আগে

কাগল শংসাপত্র

জেমা মডেলগুলি কাগল দ্বারা হোস্ট করা হয়। Gemma ব্যবহার করতে, Kaggle এ অ্যাক্সেসের অনুরোধ করুন:

  • সাইন ইন করুন বা kaggle.com এ নিবন্ধন করুন
  • জেমা মডেল কার্ড খুলুন এবং "অ্যাক্সেসের অনুরোধ করুন" নির্বাচন করুন
  • সম্মতি ফর্মটি পূরণ করুন এবং শর্তাবলী গ্রহণ করুন

তারপর, Kaggle API ব্যবহার করতে, একটি API টোকেন তৈরি করুন:

  • Kaggle সেটিংস খুলুন
  • "নতুন টোকেন তৈরি করুন" নির্বাচন করুন
  • একটি kaggle.json ফাইল ডাউনলোড করা হয়। এতে আপনার কাগল শংসাপত্র রয়েছে

নিম্নলিখিত কক্ষটি চালান এবং জিজ্ঞাসা করা হলে আপনার Kaggle শংসাপত্র লিখুন৷

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

একটি বিকল্প উপায় হল আপনার পরিবেশে KAGGLE_USERNAME এবং KAGGLE_KEY সেট করা যদি kagglehub.login() আপনার জন্য কাজ না করে।

ইনস্টলেশন

জেমা মডেলের সাথে কেরাস এবং কেরাসএনএলপি ইনস্টল করুন।

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX ব্যাকএন্ড সেট আপ করুন

JAX আমদানি করুন এবং TPU-তে একটি স্যানিটী চেক চালান। Kaggle TPUv3-8 ডিভাইস অফার করে যার প্রতিটিতে 16GB মেমরি সহ 8টি TPU কোর রয়েছে।

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

লোড মডেল

import keras
import keras_nlp

NVIDIA GPU-তে মিশ্র নির্ভুল প্রশিক্ষণের নোট

NVIDIA GPU-তে প্রশিক্ষণের সময়, মিশ্র নির্ভুলতা ( keras.mixed_precision.set_global_policy('mixed_bfloat16') ) প্রশিক্ষণের মানের উপর ন্যূনতম প্রভাব সহ প্রশিক্ষণের গতি বাড়ানোর জন্য ব্যবহার করা যেতে পারে। বেশিরভাগ ক্ষেত্রে, মিশ্র নির্ভুলতা চালু করার পরামর্শ দেওয়া হয় কারণ এটি মেমরি এবং সময় উভয়ই বাঁচায়। যাইহোক, সচেতন থাকুন যে ছোট ব্যাচের আকারে, এটি মেমরির ব্যবহারকে 1.5x বৃদ্ধি করতে পারে (ওজন দুইবার লোড করা হবে, অর্ধেক নির্ভুলতা এবং সম্পূর্ণ নির্ভুলতায়)।

অনুমানের জন্য, অর্ধ-নির্ভুলতা ( keras.config.set_floatx("bfloat16") ) কাজ করবে এবং মেমরি সংরক্ষণ করবে যখন মিশ্র-নির্ভুলতা প্রযোজ্য নয়।

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

TPUs জুড়ে বিতরণ করা ওজন এবং টেনসর সহ মডেলটি লোড করতে, প্রথমে একটি নতুন DeviceMesh তৈরি করুন। DeviceMesh বিতরণ করা গণনার জন্য কনফিগার করা হার্ডওয়্যার ডিভাইসের একটি সংগ্রহের প্রতিনিধিত্ব করে এবং ইউনিফাইড ডিস্ট্রিবিউশন API-এর অংশ হিসাবে কেরাস 3-এ চালু করা হয়েছিল।

ডিস্ট্রিবিউশন এপিআই ডেটা এবং মডেলের সমান্তরালতা সক্ষম করে, একাধিক এক্সিলারেটর এবং হোস্টগুলিতে গভীর শিক্ষার মডেলগুলির দক্ষ স্কেলিং করার অনুমতি দেয়। এটি একক প্রোগ্রাম, মাল্টিপল ডাটা (এসপিএমডি) সম্প্রসারণ নামে একটি পদ্ধতির মাধ্যমে শার্ডিং নির্দেশাবলী অনুসারে প্রোগ্রাম এবং টেনসরগুলি বিতরণ করার জন্য অন্তর্নিহিত কাঠামো (যেমন JAX) ব্যবহার করে। নতুন কেরাস 3 ডিস্ট্রিবিউশন এপিআই গাইডে আরও বিশদ দেখুন।

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

ডিস্ট্রিবিউশন এপিআই থেকে LayoutMap নির্দিষ্ট করে যে স্ট্রিং কী ব্যবহার করে কীভাবে ওজন এবং টেনসরগুলিকে শার্ড করা বা প্রতিলিপি করা উচিত, উদাহরণস্বরূপ, নীচের token_embedding/embeddings , যা টেনসর পাথের সাথে মেলে রেজেক্সের মতো আচরণ করা হয়। মিলিত টেনসরগুলি মডেলের মাত্রা (8 টিপিইউ) সহ শার্ড করা হয়; অন্যদের সম্পূর্ণরূপে প্রতিলিপি করা হবে.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel আপনাকে DeviceMesh সমস্ত devcies জুড়ে মডেলের ওজন বা অ্যাক্টিভেশন টেনসর শর্ড করতে দেয়। এই ক্ষেত্রে, Gemma 7B মডেলের কিছু ওজন উপরে সংজ্ঞায়িত layout_map অনুসারে 8 টি টিপিইউ চিপ জুড়ে শার্ড করা হয়েছে। এখন বিতরণ করা উপায়ে মডেলটি লোড করুন।

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

এখন যাচাই করুন যে মডেলটি সঠিকভাবে বিভাজন করা হয়েছে। একটি উদাহরণ হিসাবে decoder_block_1 নেওয়া যাক।

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

ফাইনটিউনিংয়ের আগে অনুমান

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

মডেলটি দেখার জন্য 90 এর দশকের দুর্দান্ত কমেডি সিনেমাগুলির একটি তালিকা তৈরি করে৷ এখন আমরা আউটপুট শৈলী পরিবর্তন করতে জেমা মডেলটি সূক্ষ্ম টিউন করি।

IMDB-এর সাথে ফাইনটিউন

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

নিম্ন র‌্যাঙ্ক অ্যাডাপ্টেশন (LoRA) ব্যবহার করে ফাইনটিউনিং করুন। LoRA হল একটি ফাইন-টিউনিং কৌশল যা মডেলের সম্পূর্ণ ওজন হিমায়িত করে এবং মডেলের মধ্যে অল্প সংখ্যক নতুন প্রশিক্ষনযোগ্য ওজন সন্নিবেশ করে ডাউনস্ট্রিম কাজের জন্য প্রশিক্ষণযোগ্য প্যারামিটারের সংখ্যা ব্যাপকভাবে হ্রাস করে। মূলত LoRA প্রশিক্ষণের জন্য 2টি ছোট লো-র‍্যাঙ্ক ম্যাট্রিস AxB দ্বারা বৃহত্তর পূর্ণ ওজনের ম্যাট্রিসগুলিকে পুনরায় প্যারামিটারাইজ করে এবং এই কৌশলটি প্রশিক্ষণকে আরও দ্রুত এবং আরও মেমরি-দক্ষ করে তোলে।

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

উল্লেখ্য যে LoRA সক্ষম করার ফলে প্রশিক্ষণযোগ্য প্যারামিটারের সংখ্যা উল্লেখযোগ্যভাবে হ্রাস পায়, 7 বিলিয়ন থেকে মাত্র 11 মিলিয়নে।

ফাইনটিউনিংয়ের পরে অনুমান

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

ফাইনটিউনিংয়ের পরে, মডেলটি মুভি রিভিউয়ের স্টাইল শিখেছে এবং এখন 90 এর দশকের কমেডি সিনেমার প্রেক্ষাপটে সেই স্টাইলে আউটপুট তৈরি করছে।

এরপর কি

এই টিউটোরিয়ালে, আপনি শিখেছেন কিভাবে কেরাসএনএলপি JAX ব্যাকএন্ড ব্যবহার করে শক্তিশালী TPU-তে বিতরণ করা পদ্ধতিতে IMDb ডেটাসেটে একটি জেমা মডেল ফিনটিউন করতে হয়। আর কী শিখতে হবে তার জন্য এখানে কয়েকটি পরামর্শ রয়েছে: