การปรับแต่งแบบกระจายด้วย Gemma โดยใช้ Keras

ดูใน ai.google.dev เรียกใช้ใน Google Colab เรียกใช้ใน Kaggle เปิดใน Vertex AI ดูซอร์สโค้ดใน GitHub

ภาพรวม

Gemma เป็นกลุ่มผลิตภัณฑ์โมเดลแบบเปิดที่ทันสมัยและน้ำหนักเบา ซึ่งสร้างขึ้นจากงานวิจัยและเทคโนโลยีที่ใช้สร้างโมเดล Google Gemini Gemma สามารถปรับแต่งเพิ่มเติมให้เหมาะกับความต้องการเฉพาะได้ แต่โมเดลภาษาขนาดใหญ่ เช่น Gemma อาจมีขนาดใหญ่มาก และบางโมเดลอาจไม่พอดีกับเครื่องเร่งความเร็วแบบ Sing สำหรับการปรับแต่งขั้นละเอียด ในกรณีนี้ การปรับแต่งมี 2 วิธีทั่วไป ดังนี้

  1. การปรับแต่งอย่างมีประสิทธิภาพ (PEFT) ของพารามิเตอร์ ซึ่งพยายามย่อขนาดโมเดลที่มีประสิทธิภาพโดยการลดความแม่นยำบางส่วน LoRA อยู่ในหมวดหมู่นี้ และบทแนะนำการปรับแต่งโมเดล Gemma ใน Keras ที่ใช้ LoRA จะสาธิตวิธีการปรับแต่งโมเดล Gemma 2B gemma_2b_en ด้วย LoRA โดยใช้ KerasNLP ใน GPU เดียว
  2. การปรับแต่งพารามิเตอร์อย่างละเอียดทั้งหมดด้วยการทำงานแบบขนานของโมเดล การทำงานแบบขนานของโมเดลจะกระจายน้ำหนักของโมเดลเดียวไปยังอุปกรณ์หลายเครื่องและเปิดใช้การปรับขนาดในแนวนอน ดูข้อมูลเพิ่มเติมเกี่ยวกับการฝึกอบรมแบบกระจายตัวได้ในคู่มือ Keras นี้

บทแนะนํานี้จะอธิบายการใช้ Keras กับแบ็กเอนด์ JAX เพื่อปรับแต่งโมเดล Gemma 7B ด้วย LoRA และการฝึกแบบกระจายแบบขนานของโมเดลใน Tensor Processing Unit (TPU) ของ Google โปรดทราบว่าคุณสามารถปิด LoRA ได้ในบทแนะนำนี้เพื่อให้การปรับแต่งพารามิเตอร์ทั้งหมดช้าลงแต่มีความแม่นยำมากขึ้น

การใช้ Accelerator

ในทางเทคนิคแล้ว คุณสามารถใช้ TPU หรือ GPU กับบทแนะนํานี้ได้

หมายเหตุเกี่ยวกับสภาพแวดล้อม TPU

Google มีผลิตภัณฑ์ 3 รายการที่มี TPU ดังนี้

  • Colab มี TPU v2 ให้ใช้งานฟรี ซึ่งเพียงพอสำหรับบทแนะนำนี้
  • Kaggle มี TPU v3 ให้ใช้งานฟรีและสามารถใช้กับบทแนะนำนี้ได้
  • Cloud TPU มี TPU v3 และรุ่นที่ใหม่กว่า วิธีตั้งค่ามีดังนี้
    1. สร้าง TPU VM ใหม่
    2. ตั้งค่าการส่งต่อพอร์ต SSH สำหรับพอร์ตเซิร์ฟเวอร์ Jupyter ที่ต้องการ
    3. ติดตั้ง Jupyter และเริ่มใช้งานใน TPU VM จากนั้นเชื่อมต่อกับ Colab ผ่าน "เชื่อมต่อกับรันไทม์ในเครื่อง"

หมายเหตุเกี่ยวกับการตั้งค่า GPU หลายตัว

แม้ว่าบทแนะนำนี้จะเน้นที่ Use Case ของ TPU แต่คุณก็ปรับให้เหมาะกับความต้องการของตัวเองได้ง่ายๆ หากมีเครื่องที่มี GPU หลายตัว

หากต้องการทํางานผ่าน Colab คุณสามารถจัดสรร VM แบบหลาย GPU สําหรับ Colab ได้โดยตรงผ่าน "เชื่อมต่อ GCE VM แบบกําหนดเอง" ในเมนู Colab Connect

เราจะมุ่งเน้นที่การใช้ TPU ฟรีจาก Kaggle ที่นี่

ก่อนเริ่มต้น

ข้อมูลเข้าสู่ระบบ Kaggle

โมเดล Gemma โฮสต์โดย Kaggle หากต้องการใช้ Gemma ให้ขอสิทธิ์เข้าถึงใน Kaggle โดยทำดังนี้

  • ลงชื่อเข้าใช้หรือลงทะเบียนที่ kaggle.com
  • เปิดการ์ดรุ่น Gemma แล้วเลือก"ขอสิทธิ์เข้าถึง"
  • กรอกแบบฟอร์มความยินยอมและยอมรับข้อกำหนดและเงื่อนไข

จากนั้นสร้างโทเค็น API เพื่อใช้ Kaggle API โดยทำดังนี้

  • เปิดการตั้งค่า Kaggle
  • เลือก"สร้างโทเค็นใหม่"
  • ดาวน์โหลดไฟล์ kaggle.json แล้ว มีข้อมูลเข้าสู่ระบบ Kaggle ของคุณ

เรียกใช้เซลล์ต่อไปนี้และป้อนข้อมูลเข้าสู่ระบบ Kaggle เมื่อระบบขอ

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

อีกวิธีหนึ่งคือตั้งค่า KAGGLE_USERNAME และ KAGGLE_KEY ในสภาพแวดล้อมหาก kagglehub.login() ไม่ทำงาน

การติดตั้ง

ติดตั้ง Keras และ KerasNLP ด้วยโมเดล Gemma

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

ตั้งค่าแบ็กเอนด์ Keras JAX

นำเข้า JAX และเรียกใช้การตรวจสอบความถูกต้องบน TPU Kaggle มีอุปกรณ์ TPUv3-8 ที่มีแกน TPU 8 แกนพร้อมหน่วยความจำ 16 GB ในแต่ละชิ้น

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

โหลดโมเดล

import keras
import keras_nlp

หมายเหตุเกี่ยวกับการฝึกด้วยความละเอียดแบบผสมใน GPU ของ NVIDIA

เมื่อฝึกใน GPU ของ NVIDIA คุณสามารถใช้ความแม่นยำแบบผสม (keras.mixed_precision.set_global_policy('mixed_bfloat16')) เพื่อเร่งการฝึกโดยให้ส่งผลต่อคุณภาพการฝึกน้อยที่สุด ในกรณีส่วนใหญ่ ขอแนะนำให้เปิดใช้ความแม่นยำแบบผสมเพราะจะช่วยประหยัดทั้งหน่วยความจำและเวลา อย่างไรก็ตาม โปรดทราบว่าเมื่อใช้กลุ่มที่มีขนาดเล็ก การใช้หน่วยความจําอาจเพิ่มขึ้น 1.5 เท่า (ระบบจะโหลดน้ำหนัก 2 ครั้ง โดยโหลดที่ความแม่นยำระดับครึ่งหนึ่งและความแม่นยำระดับเต็ม)

สําหรับการอนุมาน ความละเอียดครึ่งหนึ่ง (keras.config.set_floatx("bfloat16")) จะทํางานและประหยัดหน่วยความจําได้ ส่วนความละเอียดแบบผสมจะใช้ไม่ได้

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

หากต้องการโหลดโมเดลด้วยน้ำหนักและ Tensor ที่กระจายทั่วทั้ง TPU ให้สร้าง DeviceMesh ใหม่ก่อน DeviceMesh แสดงถึงคอลเล็กชันของอุปกรณ์ฮาร์ดแวร์ที่กำหนดค่าไว้สำหรับการประมวลผลแบบกระจาย และเปิดตัวใน Keras 3 โดยเป็นส่วนหนึ่งของ API การจัดจำหน่ายแบบรวม

Distribution API ช่วยให้ข้อมูลและโมเดลทำงานแบบขนานได้ ซึ่งช่วยให้ปรับขนาดโมเดลการเรียนรู้เชิงลึกในเครื่องเร่งความเร็วและโฮสต์หลายเครื่องได้อย่างมีประสิทธิภาพ โดยใช้ประโยชน์จากเฟรมเวิร์กพื้นฐาน (เช่น JAX) เพื่อกระจายโปรแกรมและเทนเซอร์ตามคำสั่งการแยกผ่านกระบวนการที่เรียกว่าการขยายโปรแกรมเดียวหลายข้อมูล (SPMD) ดูรายละเอียดเพิ่มเติมได้ในคู่มือ API ของ Keras 3 Distribution ฉบับใหม่

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap จาก API การแจกจ่ายจะระบุวิธีชาร์ดหรือจำลองน้ำหนักและ Tensor โดยใช้คีย์สตริง เช่น token_embedding/embeddings ด้านล่าง ซึ่งจะดำเนินการเหมือนกับนิพจน์ทั่วไปในการจับคู่เส้นทาง tensor Tensor ที่ตรงกันจะแบ่งออกเป็นกลุ่มตามมิติข้อมูลของโมเดล (TPU 8 ตัว) ส่วน Tensor อื่นๆ จะได้รับการทําซ้ำทั้งหมด

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel ให้คุณแบ่งน้ำหนักโมเดลหรือเทนเซอร์การเปิดใช้งานในอุปกรณ์ทั้งหมดใน DeviceMesh ได้ ในกรณีนี้ น้ำหนักโมเดล Gemma 7B บางรายการจะแบ่งออกเป็นกลุ่มในชิป TPU 8 ชิปตาม layout_map ที่กําหนดไว้ด้านบน คราวนี้ให้โหลดโมเดลแบบกระจาย

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

ตอนนี้ให้ตรวจสอบว่ามีการแบ่งพาร์ติชันโมเดลอย่างถูกต้องแล้ว มาดูตัวอย่างจาก decoder_block_1

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

การอนุมานก่อนการปรับแต่ง

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

โมเดลนี้จะสร้างรายการภาพยนตร์ตลกยอดเยี่ยมจากยุค 90 ให้รับชม ตอนนี้เราจะปรับแต่งโมเดล Gemma เพื่อเปลี่ยนรูปแบบเอาต์พุต

ปรับแต่งด้วย IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

ปรับแต่งโดยใช้การปรับอันดับต่ำ (LoRA) LoRA เป็นเทคนิคการปรับแต่งซึ่งช่วยลดจำนวนพารามิเตอร์ที่ฝึกได้สำหรับงานดาวน์สตรีมลงได้อย่างมากด้วยการตรึงน้ำหนักเต็มของโมเดลและใส่น้ำหนักใหม่ที่ฝึกได้ลงในโมเดลจำนวนที่น้อยลง พูดง่ายๆ ก็คือ LoRA จะเปลี่ยนพารามิเตอร์เมทริกซ์แบบเต็มน้ำหนักที่ใหญ่กว่าโดยใช้ AxB เมทริกซ์ระดับต่ำ 2 เมทริกซ์ที่เล็กกว่าเพื่อฝึก และเทคนิคนี้ทำให้การฝึกเร็วขึ้นและประหยัดหน่วยความจำได้มากยิ่งขึ้น

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

โปรดทราบว่าการเปิดใช้ LoRA จะลดจํานวนพารามิเตอร์ที่ฝึกได้อย่างมาก จาก 7 พันล้านเหลือเพียง 11 ล้าน

การอนุมานหลังการปรับแต่ง

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

หลังจากปรับแต่งอย่างละเอียด โมเดลก็ได้เรียนรู้สไตล์การรีวิวภาพยนตร์ และสร้างผลลัพธ์ในสไตล์ดังกล่าวในบริบทของภาพยนตร์ตลกยุค 90

ขั้นตอนถัดไป

ในบทแนะนํานี้ คุณได้เรียนรู้วิธีใช้แบ็กเอนด์ KerasNLP JAX เพื่อปรับแต่งโมเดล Gemma ในชุดข้อมูล IMDb ในลักษณะแบบกระจายบน TPU ที่มีประสิทธิภาพ ต่อไปนี้เป็นคำแนะนำเล็กๆ น้อยๆ เกี่ยวกับสิ่งอื่นที่ควรเรียนรู้