Keras का इस्तेमाल करके, जेमा के साथ डिस्ट्रिब्यूटेड ट्यूनिंग

ai.google.dev पर देखें Google Colab में चलाना Kaggle में चलाएं Vertex AI में खोलें GitHub पर सोर्स देखना

खास जानकारी

Gemma एक लाइटवेट और बेहतरीन ओपन मॉडल है. इसे Google Gemini के मॉडल बनाने में इस्तेमाल की गई रिसर्च और टेक्नोलॉजी का इस्तेमाल करके बनाया गया है. Gemma को खास ज़रूरतों के हिसाब से बेहतर बनाया जा सकता है. हालांकि, जेमा जैसे बड़े लैंग्वेज मॉडल का साइज़ बहुत बड़ा हो सकता है. साथ ही, हो सकता है कि इनमें से कुछ मॉडल को फ़ाइन-ट्यून करने के लिए, सिंगल ऐक्सेलरेटर का इस्तेमाल न किया जा सके. इस मामले में, उन्हें बेहतर बनाने के लिए दो सामान्य तरीके हैं:

  1. पैरामीटर एफ़िशिएंट फ़ाइन-ट्यूनिंग (PEFT), जो कुछ फ़िडेलिटी का त्याग करके, मॉडल के साइज़ को छोटा करने की कोशिश करता है. LoRA इस कैटगरी में आता है. LoRA का इस्तेमाल करके, Keras में Gemma मॉडल को फ़ाइन-ट्यून करना ट्यूटोरियल में, एक ही जीपीयू पर KerasNLP का इस्तेमाल करके, LoRA की मदद से Gemma 2B मॉडल gemma_2b_en को फ़ाइन-ट्यून करने का तरीका बताया गया है.
  2. मॉडल के पैरलल प्रोसेसिंग के साथ, पैरामीटर को पूरी तरह से फ़ाइन-ट्यून करना. मॉडल पैरलललिज़्म, एक मॉडल के वेट को कई डिवाइसों पर बांटता है और हॉरिज़ॉन्टल स्केलिंग की सुविधा चालू करता है. डिस्ट्रिब्यूटेड ट्रेनिंग के बारे में ज़्यादा जानने के लिए, Keras गाइड पढ़ें.

इस ट्यूटोरियल में, JAX बैकएंड के साथ Keras का इस्तेमाल करने का तरीका बताया गया है. इससे, Google की Tensor Processing Unit (TPU) पर LoRA और मॉडल-पैरैललिज़्म डिस्ट्रिब्यूटेड ट्रेनिंग की मदद से, Gemma 7B मॉडल को बेहतर बनाया जा सकता है. ध्यान दें कि इस ट्यूटोरियल में LoRA को बंद किया जा सकता है, ताकि पूरी पैरामीटर ट्यूनिंग धीमी हो, लेकिन ज़्यादा सटीक हो.

ऐक्सेलरेटर का इस्तेमाल करना

तकनीकी तौर पर, इस ट्यूटोरियल के लिए TPU या GPU, दोनों में से किसी का भी इस्तेमाल किया जा सकता है.

TPU एनवायरमेंट के बारे में जानकारी

Google के तीन प्रॉडक्ट में टीपीयू की सुविधा मिलती है:

  • Colab, TPU v2 को बिना किसी शुल्क के उपलब्ध कराता है. यह ट्यूटोरियल के लिए काफ़ी है.
  • Kaggle, TPU v3 को बिना किसी शुल्क के उपलब्ध कराता है. यह ट्यूटोरियल के लिए भी काम करता है.
  • Cloud TPU, TPU v3 और नई जनरेशन के TPU उपलब्ध कराता है. इसे सेट अप करने का एक तरीका यह है:
    1. नया TPU VM बनाना
    2. अपने Jupyter सर्वर पोर्ट के लिए एसएसएच पोर्ट फ़ॉरवर्डिंग सेट अप करें
    3. Jupyter इंस्टॉल करें और उसे TPU VM पर शुरू करें. इसके बाद, "लोकल रनटाइम से कनेक्ट करें" के ज़रिए Colab से कनेक्ट करें

एक से ज़्यादा जीपीयू के सेटअप के बारे में जानकारी

इस ट्यूटोरियल में, टीपीयू के इस्तेमाल के उदाहरण पर फ़ोकस किया गया है. हालांकि, अगर आपके पास कई जीपीयू वाली मशीन है, तो इसे अपनी ज़रूरतों के हिसाब से आसानी से बदला जा सकता है.

अगर आपको Colab का इस्तेमाल करना है, तो Colab के लिए कई जीपीयू वाला वर्चुअल मशीन भी उपलब्ध कराया जा सकता है. इसके लिए, Colab के 'कनेक्ट करें' मेन्यू में जाकर, "कस्टम GCE (जीसीई) वर्चुअल मशीन से कनेक्ट करें" पर जाएं.

हम यहां Kaggle के मुफ़्त TPU का इस्तेमाल करने पर फ़ोकस करेंगे.

शुरू करने से पहले

Kaggle के क्रेडेंशियल

Gemma मॉडल, Kaggle पर होस्ट किए जाते हैं. Gemma का इस्तेमाल करने के लिए, Kaggle पर ऐक्सेस का अनुरोध करें:

  • kaggle.com पर साइन इन करें या रजिस्टर करें
  • Gemma मॉडल कार्ड खोलें और "ऐक्सेस का अनुरोध करें" चुनें
  • सहमति वाला फ़ॉर्म भरें और नियम और शर्तें स्वीकार करें

इसके बाद, Kaggle API का इस्तेमाल करने के लिए, एपीआई टोकन बनाएं:

  • Kaggle की सेटिंग खोलें
  • "नया टोकन बनाएं" चुनें
  • kaggle.json फ़ाइल डाउनलोड हो गई. इसमें आपके Kaggle क्रेडेंशियल शामिल हैं

नीचे दी गई सेल को चलाएं और पूछे जाने पर अपने Kaggle क्रेडेंशियल डालें.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

अगर kagglehub.login() आपके लिए काम नहीं करता है, तो अपने वातावरण में KAGGLE_USERNAME और KAGGLE_KEY को सेट करना भी एक दूसरा तरीका है.

इंस्टॉल करना

Gemma मॉडल के साथ Keras और KerasNLP इंस्टॉल करें.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX बैकएंड सेट अप करना

JAX इंपोर्ट करें और TPU पर जांच करें. Kaggle, TPUv3-8 डिवाइसों की सुविधा देता है. इनमें आठ TPU कोर होते हैं और हर डिवाइस में 16 जीबी मेमोरी होती है.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

मॉडल लोड करना

import keras
import keras_nlp

NVIDIA जीपीयू पर अलग-अलग सटीक ट्रेनिंग के बारे में जानकारी

NVIDIA जीपीयू पर ट्रेनिंग करते समय, ट्रेनिंग की क्वालिटी पर कम से कम असर डालते हुए, ट्रेनिंग की स्पीड बढ़ाने के लिए, मिक्स्ड प्रिसीज़न (keras.mixed_precision.set_global_policy('mixed_bfloat16')) का इस्तेमाल किया जा सकता है. ज़्यादातर मामलों में, हमारा सुझाव है कि आप 'मिक्स्ड प्रिसीज़न' को चालू करें, क्योंकि इससे मेमोरी और समय दोनों की बचत होती है. हालांकि, ध्यान रखें कि छोटे बैच साइज़ में, मेमोरी के इस्तेमाल में 1.5 गुना बढ़ोतरी हो सकती है. इसका मतलब है कि वेट को दो बार लोड किया जाएगा, एक बार आधी सटीक जानकारी के साथ और एक बार पूरी सटीक जानकारी के साथ.

अनुमान लगाने के लिए, हफ़्फ़-प्रिसिशन (keras.config.set_floatx("bfloat16")) का इस्तेमाल किया जा सकता है. इससे मेमोरी की बचत होती है. हालांकि, मिक्स्ड-प्रिसिशन का इस्तेमाल नहीं किया जा सकता.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

TPU में डिस्ट्रिब्यूट किए गए वज़न और टेंसर के साथ मॉडल लोड करने के लिए, पहले नया DeviceMesh बनाएं. DeviceMesh, डिस्ट्रिब्यूटेड कंप्यूटेशन के लिए कॉन्फ़िगर किए गए हार्डवेयर डिवाइसों के कलेक्शन को दिखाता है. इसे यूनिफ़ाइड डिस्ट्रिब्यूशन एपीआई के हिस्से के तौर पर, Keras 3 में पेश किया गया था.

डिस्ट्रिब्यूशन एपीआई, डेटा और मॉडल के पैरलल प्रोसेसिंग की सुविधा देता है. इससे, कई ऐक्सेलरेटर और होस्ट पर डीप लर्निंग मॉडल को बेहतर तरीके से स्केल किया जा सकता है. यह प्रोग्राम और टेंसर को, शर्डिंग निर्देशों के हिसाब से डिस्ट्रिब्यूट करने के लिए, मौजूदा फ़्रेमवर्क (जैसे, JAX) का इस्तेमाल करता है. इसके लिए, एक प्रोग्राम, कई डेटा (एसपीएमडी) एक्सपैंशन नाम की प्रोसेस का इस्तेमाल किया जाता है. ज़्यादा जानकारी के लिए, Keras 3 डिस्ट्रिब्यूशन एपीआई की नई गाइड देखें.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

डिस्ट्रिब्यूशन एपीआई का LayoutMap, स्ट्रिंग बटन का इस्तेमाल करके यह तय करता है कि वेट और टेंसर को कैसे शेयर किया जाए या डुप्लीकेट किया जाए. उदाहरण के लिए, नीचे दिया गया token_embedding/embeddings, जिसे टेंसर पाथ से मैच करने के लिए रेगुलर एक्सप्रेशन की तरह माना जाता है. मैच होने वाले टेंसर को मॉडल डाइमेंशन (8 टीपीयू) के साथ शेयर किया जाता है. अन्य टेंसर की पूरी तरह से कॉपी बनाई जाएगी.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel की मदद से, DeviceMesh पर सभी डिवाइसों में मॉडल के वेट या ऐक्टिवेशन टेंसर को अलग-अलग हिस्सों में बांटा जा सकता है. इस मामले में, Gemma 7B मॉडल के कुछ वेट को ऊपर बताए गए layout_map के मुताबिक, आठ टीपीयू चिप में बांटा गया है. अब मॉडल को डिस्ट्रिब्यूट किए गए तरीके से लोड करें.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

अब पुष्टि करें कि मॉडल को सही तरीके से पार्टिशन किया गया है. आइए, decoder_block_1 को उदाहरण के तौर पर लेते हैं.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

फ़ाइन-ट्यून करने से पहले अनुमान लगाना

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

मॉडल, 90 के दशक की बेहतरीन कॉमेडी फ़िल्मों की सूची जनरेट करता है. अब हम आउटपुट स्टाइल बदलने के लिए, Gemma मॉडल को बेहतर बनाते हैं.

आईएमडीबी की मदद से फ़ाइन ट्यून करें

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

लो रैंक अडैप्टेशन (LoRA) का इस्तेमाल करके, मॉडल को बेहतर बनाएं. LoRA, फ़ाइन-ट्यूनिंग की एक तकनीक है. यह मॉडल के पूरे वेट को फ़्रीज़ करके और मॉडल में ट्रेन किए जा सकने वाले कम नए वेट डालकर, डाउनस्ट्रीम टास्क के लिए ट्रेन किए जा सकने वाले पैरामीटर की संख्या को काफ़ी कम कर देती है. आम तौर पर, LoRA ट्रेनिंग के लिए, बड़े फ़ुल वेट मैट्रिक्स को दो छोटे लो-रैंक मैट्रिक्स AxB में बदल देता है. इस तकनीक से, ट्रेनिंग बहुत तेज़ी से और ज़्यादा मेमोरी के साथ की जा सकती है.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

ध्यान दें कि LoRA को चालू करने से, ट्रेन किए जा सकने वाले पैरामीटर की संख्या काफ़ी कम हो जाती है. यह संख्या सात अरब से घटकर सिर्फ़ 11 मिलियन हो जाती है.

फ़ाइन-ट्यून करने के बाद अनुमान लगाना

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

बेहतर बनाने के बाद, मॉडल ने फ़िल्मों की समीक्षा करने की स्टाइल सीखी. अब यह 90 के दशक की कॉमेडी फ़िल्मों के लिए, इसी स्टाइल के हिसाब से कॉन्टेंट जनरेट कर रहा है.

आगे क्या करना है

इस ट्यूटोरियल में, आपको KerasNLP JAX बैकएंड का इस्तेमाल करके, IMDb डेटासेट पर Gemma मॉडल को बेहतर बनाने का तरीका पता चला. इसके लिए, ज़्यादा बेहतर TPUs का इस्तेमाल किया गया. यहां कुछ और चीज़ों के बारे में जानने के लिए सुझाव दिए गए हैं: