การปรับแต่งแบบกระจายด้วย Gemma โดยใช้ Keras

ดูใน ai.google.dev เรียกใช้ใน Kaggle เปิดใน Vertex AI ดูซอร์สบน GitHub

ภาพรวม

Gemma เป็นกลุ่มโมเดลแบบเปิดที่ทันสมัยและมีน้ำหนักเบา ซึ่งสร้างขึ้นจากการวิจัยและเทคโนโลยีที่ใช้ในการสร้างโมเดล Google Gemini คุณยังสามารถปรับ Gemma ให้ตรงตามความต้องการเฉพาะได้ แต่โมเดลภาษาขนาดใหญ่ เช่น Gemma อาจมีขนาดที่ใหญ่มากและโมเดลบางรายการอาจไม่พอดีกับตัวเร่งการร้องเพลงสำหรับการปรับแต่งอย่างละเอียด ในกรณีนี้ มีวิธีการทั่วไป 2 วิธีในการปรับแต่งตัวกรองเหล่านั้น

  1. พารามิเตอร์ Fine-Tuning (PEFT) อย่างมีประสิทธิภาพซึ่งจะพยายามลดขนาดโมเดลที่มีประสิทธิภาพโดยลดความแม่นยำบางส่วน LoRA จัดอยู่ในหมวดหมู่นี้ และบทแนะนำการปรับแต่ง Gemma ใน Keras ที่ใช้ LoRA จะสาธิตวิธีปรับแต่งโมเดล Gemma 2B gemma_2b_en ด้วย LoRA โดยใช้ KerasNLP บน GPU เดียว
  2. การปรับแต่งพารามิเตอร์แบบเต็มด้วยการทำงานพร้อมกันของโมเดล การโหลดพร้อมกันของโมเดลจะกระจายน้ำหนักของโมเดลเดียวในหลายๆ อุปกรณ์และเปิดใช้การปรับขนาดในแนวนอน ดูข้อมูลเพิ่มเติมเกี่ยวกับการฝึกอบรมแบบกระจายได้ในคู่มือ Keras นี้

บทแนะนำนี้จะอธิบายถึงการใช้ Keras กับแบ็กเอนด์ JAX เพื่อปรับแต่งโมเดล Gemma 7B ด้วย LoRA และการฝึกแบบกระจายโมเดลแบบกระจายตัวของโมเดล Tensor Processing Unit (TPU) ของ Google โปรดทราบว่าคุณปิด LoRA ได้ในบทแนะนำนี้สำหรับการปรับแต่งพารามิเตอร์แบบเต็มที่ช้ากว่าแต่แม่นยํามากกว่า

การใช้ Accelerator

ในทางเทคนิค คุณสามารถใช้ TPU หรือ GPU สำหรับบทแนะนำนี้

หมายเหตุเกี่ยวกับสภาพแวดล้อม TPU

Google มีผลิตภัณฑ์ 3 รายการที่มอบ TPU ดังนี้

  • Colab ให้ TPU v2 ซึ่งไม่เพียงพอสำหรับบทแนะนำนี้
  • Kaggle ให้บริการ TPU v3 ฟรีและใช้ได้กับบทแนะนำนี้
  • Cloud TPU มี TPU v3 และรุ่นใหม่กว่า วิธีการตั้งค่าวิธีหนึ่งคือ
    1. สร้าง TPU VM ใหม่
    2. ตั้งค่าการส่งต่อพอร์ต SSH สำหรับพอร์ตเซิร์ฟเวอร์ Jupyter ที่ต้องการ
    3. ติดตั้ง Jupyter และเริ่มการทำงานบน VM ของ TPU จากนั้นเชื่อมต่อกับ Colab ผ่าน "เชื่อมต่อกับรันไทม์ในเครื่อง"

หมายเหตุเกี่ยวกับการตั้งค่าหลาย GPU

แม้ว่าบทแนะนำนี้จะเน้นที่ Use Case ของ TPU แต่คุณก็ปรับเอามาให้เหมาะกับความต้องการของตัวเองได้ง่ายๆ ถ้ามีเครื่องแบบ Multi-GPU

หากต้องการใช้งาน Colab คุณก็จัดสรร VM หลาย GPU สำหรับ Colab ได้โดยตรงผ่าน "เชื่อมต่อกับ GCE VM ที่กำหนดเอง" ในเมนู Colab Connect

เราจะมุ่งเน้นที่การใช้ TPU ฟรีจาก Kaggle ที่นี่

ก่อนเริ่มต้น

ข้อมูลเข้าสู่ระบบ Kaggle

โมเดลของ Gemma โฮสต์โดย Kaggle หากต้องการใช้ Gemma ให้ขอสิทธิ์เข้าถึงใน Kaggle

  • ลงชื่อเข้าใช้หรือลงทะเบียนที่ kaggle.com
  • เปิดการ์ดโมเดล Gemma แล้วเลือก "ขอสิทธิ์เข้าถึง"
  • กรอกแบบฟอร์มความยินยอมและยอมรับข้อกำหนดและเงื่อนไข

จากนั้น หากต้องการใช้ Kaggle API ให้สร้างโทเค็น API โดยทำดังนี้

  • เปิดการตั้งค่า Kaggle
  • เลือก "Create New Token"
  • ระบบจะดาวน์โหลดไฟล์ kaggle.json มีข้อมูลเข้าสู่ระบบ Kaggle ของคุณ

เรียกใช้เซลล์ต่อไปนี้และป้อนข้อมูลเข้าสู่ระบบ Kaggle เมื่อระบบถาม

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

อีกวิธีหนึ่งคือการตั้งค่า KAGGLE_USERNAME และ KAGGLE_KEY ในสภาพแวดล้อมของคุณหาก kagglehub.login() ไม่ทำงานสำหรับคุณ

บริการติดตั้ง

ติดตั้ง Keras และ KerasNLP ด้วยโมเดล Gemma

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

ตั้งค่าแบ็กเอนด์ของ Keras JAX

นำเข้า JAX และเรียกใช้การตรวจสอบความปลอดภัยบน TPU Kaggle มีอุปกรณ์ TPUv3-8 ที่มีแกน TPU 8 แกน พร้อมหน่วยความจำ 16 GB ต่อเครื่อง

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

โหลดโมเดล

import keras
import keras_nlp

หมายเหตุเกี่ยวกับการฝึกความแม่นยำแบบผสมบน GPU NVIDIA

เมื่อฝึกบน GPU NVIDIA คุณสามารถใช้ความแม่นยำผสม (keras.mixed_precision.set_global_policy('mixed_bfloat16')) เพื่อเร่งความเร็วในการฝึกโดยมีผลกระทบต่อคุณภาพการฝึกน้อยที่สุด ในกรณีส่วนใหญ่ ขอแนะนำให้เปิดใช้ความแม่นยำแบบผสมเนื่องจากจะช่วยประหยัดทั้งหน่วยความจำและเวลา อย่างไรก็ตาม โปรดทราบว่าไฟล์ขนาดเล็กอาจทำให้การใช้งานหน่วยความจำเพิ่มขึ้น 1.5 เท่า (โหลดน้ำหนัก 2 ครั้งโดยใช้ความแม่นยำครึ่งหนึ่งและความแม่นยำเต็ม)

สำหรับการอนุมาน แบบครึ่งความแม่นยำ (keras.config.set_floatx("bfloat16")) จะทำงานและประหยัดหน่วยความจำได้ แต่จะใช้ความแม่นยำแบบผสมไม่ได้

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

หากต้องการโหลดโมเดลที่มีน้ำหนักและ Tensor ที่กระจายทั่ว TPU ให้สร้าง DeviceMesh ใหม่ก่อน DeviceMesh แสดงคอลเล็กชันอุปกรณ์ฮาร์ดแวร์ที่กำหนดค่าไว้สำหรับการคำนวณแบบกระจายและเปิดตัวใน Keras 3 เป็นส่วนหนึ่งของ Unified Distribution API

API การเผยแพร่ช่วยให้ข้อมูลและโมเดลดำเนินการพร้อมกัน ทำให้สามารถปรับขนาดโมเดลการเรียนรู้เชิงลึกได้อย่างมีประสิทธิภาพบน Accelerator และโฮสต์หลายรายการ บริษัทใช้ประโยชน์จากเฟรมเวิร์กพื้นฐาน (เช่น JAX) เพื่อแจกจ่ายโปรแกรมและ Tensor ตามคำสั่งชาร์ดดิ้งผ่านกระบวนการที่เรียกว่าโปรแกรมเดี่ยว การขยายข้อมูลหลายรายการ (SPMD) ดูรายละเอียดเพิ่มเติมได้ในคู่มือ Keras 3 Distribution API ใหม่

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap จาก API การกระจายจะระบุวิธีที่ควรชาร์ดหรือจำลองน้ำหนักและ tensor โดยใช้คีย์สตริง เช่น token_embedding/embeddings ด้านล่าง ซึ่งจะมีการดำเนินการเหมือนกับ regex เพื่อจับคู่เส้นทาง tensor Tensor ที่ตรงกันจะถูกชาร์ดด้วยมิติข้อมูลของโมเดล (8 TPU) และจะมีการจำลองตัวเองอย่างเต็มรูปแบบ

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel ช่วยให้คุณชาร์ดน้ำหนักของโมเดลหรือ Tensor การเปิดใช้งานในอุปกรณ์ทั้งหมดใน DeviceMesh ได้ ในกรณีนี้ น้ำหนักของโมเดล Gemma 7B บางส่วนจะมีการชาร์ดในชิป TPU 8 ชิปตาม layout_map ที่กำหนดไว้ข้างต้น ตอนนี้ ให้โหลดโมเดลด้วยวิธีการแบบกระจาย

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

ตอนนี้ให้ยืนยันว่าโมเดลได้รับการแบ่งพาร์ติชันอย่างถูกต้องแล้ว ลองดู decoder_block_1 เป็นตัวอย่าง

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

การอนุมานก่อนปรับแต่ง

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

โดยโมเดลดังกล่าวจะสร้างรายชื่อภาพยนตร์ตลกยอดเยี่ยมจากยุค 90 ให้รับชม ตอนนี้เราปรับแต่งโมเดล Gemma เพื่อเปลี่ยนรูปแบบเอาต์พุต

ปรับแต่งด้วย IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

ปรับแต่งโดยใช้ Low Rank Adaptation (LoRA) LoRA เป็นเทคนิคการปรับแต่งที่ช่วยลดจำนวนพารามิเตอร์ที่ฝึกได้สำหรับงานปลายทางได้เป็นอย่างมาก โดยการตรึงน้ำหนักทั้งหมดของโมเดลแล้วแทรกน้ำหนักที่ฝึกได้ใหม่ลงในโมเดลจำนวนน้อยลง โดยพื้นฐานแล้ว LoRA จะปรับพารามิเตอร์เมทริกซ์น้ำหนักเต็มที่มีขนาดใหญ่ขึ้นอีกครั้งด้วย AxB ระดับต่ำ 2 รายการที่เล็กลงเพื่อฝึก และเทคนิคนี้ทำให้การฝึกเร็วขึ้นมากและมีประสิทธิภาพหน่วยความจำมากขึ้น

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

โปรดทราบว่าการเปิดใช้ LoRA จะลดจำนวนพารามิเตอร์ที่ฝึกได้อย่างมีนัยสำคัญ จาก 7 พันล้านเป็น 11 ล้านครั้ง

การอนุมานหลังการปรับแต่ง

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

หลังจากปรับแต่ง นายแบบ/นางแบบก็ได้เรียนรู้สไตล์การรีวิวภาพยนตร์และตอนนี้ก็ได้สร้างผลลัพธ์ในรูปแบบนั้นในบริบทของภาพยนตร์ตลกยุค 90

ขั้นตอนถัดไป

ในบทแนะนำนี้ คุณได้เรียนรู้วิธีใช้แบ็กเอนด์ KerasNLP JAX เพื่อปรับแต่งโมเดล Gemma บนชุดข้อมูล IMDb ด้วยวิธีที่มีการกระจายบน TPU ที่มีประสิทธิภาพ คำแนะนำ 2-3 ข้อสำหรับสิ่งอื่นๆ ที่ควรเรียนรู้มีดังนี้