ดูใน ai.google.dev | เรียกใช้ใน Google Colab | เรียกใช้ใน Kaggle | เปิดใน Vertex AI | ดูแหล่งที่มาใน GitHub |
ภาพรวม
Gemma คือตระกูลโมเดลแบบเปิดที่ทันสมัยน้ำหนักเบา สร้างขึ้นจากการวิจัยและเทคโนโลยีที่ใช้ในการสร้างโมเดล Google Gemini นอกจากนี้ยังสามารถปรับแต่ง Gemma เพิ่มเติมให้เหมาะกับความต้องการที่เฉพาะเจาะจงได้ แต่โมเดลภาษาขนาดใหญ่ เช่น Gemma อาจมีขนาดใหญ่มาก และบางโมเดลอาจไม่เหมาะกับตัวเร่งเสียงสำหรับการปรับแต่ง ในกรณีนี้ มีวิธีการทั่วไป 2 วิธีในการปรับแต่ง
- การปรับแต่งอย่างมีประสิทธิภาพ (PEFT) ของพารามิเตอร์ ซึ่งพยายามย่อขนาดโมเดลที่มีประสิทธิภาพโดยการลดความแม่นยำบางส่วน LoRA อยู่ในหมวดหมู่นี้ และบทแนะนำการปรับแต่งโมเดล Gemma ใน Keras โดยใช้ LoRA จะสาธิตวิธีการปรับแต่งโมเดล Gemma 2B
gemma_2b_en
ด้วย LoRA โดยใช้ KerasNLP ใน GPU เดียว - การปรับแต่งพารามิเตอร์ทั้งหมดด้วยการทำงานพร้อมกันของโมเดล การทำงานพร้อมกันของโมเดลจะกระจายน้ำหนักของโมเดลเดียวในอุปกรณ์หลายเครื่องและเปิดใช้การปรับขนาดในแนวนอน ดูข้อมูลเพิ่มเติมเกี่ยวกับการฝึกอบรมแบบกระจายตัวได้ในคู่มือ Keras นี้
บทแนะนำนี้จะแนะนำวิธีใช้ Keras กับแบ็กเอนด์ JAX เพื่อปรับแต่งโมเดล Gemma 7B ด้วย LoRA และการฝึกโมเดลแบบกระจายตัวใน Tensor Processing Unit (TPU) ของ Google โปรดทราบว่าคุณสามารถปิด LoRA ได้ในบทแนะนำนี้เพื่อให้การปรับแต่งพารามิเตอร์ทั้งหมดช้าลงแต่มีความแม่นยำมากขึ้น
การใช้ Accelerator
ในทางเทคนิค คุณจะใช้ TPU หรือ GPU ก็ได้สำหรับบทแนะนำนี้
หมายเหตุเกี่ยวกับสภาพแวดล้อม TPU
Google มีผลิตภัณฑ์ 3 รายการที่มี TPU ดังนี้
- Colab มี TPU v2 ให้ใช้งานแบบไม่เสียค่าใช้จ่าย ซึ่งเพียงพอสำหรับบทแนะนำนี้
- Kaggle มี TPU v3 ให้ใช้งานฟรีและใช้สำหรับบทแนะนำนี้ได้เช่นกัน
- Cloud TPU มี TPU v3 และรุ่นใหม่กว่า วิธีการตั้งค่าอย่างหนึ่งคือ
- สร้าง TPU VM ใหม่
- ตั้งค่าการส่งต่อพอร์ต SSH สำหรับพอร์ตเซิร์ฟเวอร์ Jupyter ที่ต้องการ
- ติดตั้ง Jupyter และเริ่มต้นบน VM ของ TPU จากนั้นเชื่อมต่อกับ Colab ผ่าน "เชื่อมต่อกับรันไทม์ในเครื่อง"
หมายเหตุเกี่ยวกับการตั้งค่า Multi-GPU
แม้ว่าบทแนะนำนี้จะเน้นไปที่กรณีการใช้งาน TPU แต่คุณก็สามารถนำไปปรับใช้ตามความต้องการได้อย่างง่ายดายหากมีเครื่องแบบ Multi-GPU
หากต้องการทํางานผ่าน Colab คุณก็จัดสรร VM แบบ Multi-GPU สำหรับ Colab โดยตรงผ่าน "เชื่อมต่อกับ GCE VM ที่กำหนดเอง" ได้ด้วย ในเมนู Colab Connect
เราจะมุ่งเน้นที่การใช้ TPU ฟรีจาก Kaggle ที่นี่
ก่อนเริ่มต้น
ข้อมูลเข้าสู่ระบบ Kaggle
โมเดล Gemma โฮสต์โดย Kaggle หากต้องการใช้ Gemma ให้ขอสิทธิ์เข้าถึง Kaggle ดังนี้
- ลงชื่อเข้าใช้หรือลงทะเบียนที่ kaggle.com
- เปิดการ์ดโมเดล Gemma แล้วเลือก "ขอสิทธิ์เข้าถึง"
- กรอกแบบฟอร์มความยินยอมและยอมรับข้อกำหนดและเงื่อนไข
จากนั้น หากต้องการใช้ Kaggle API ให้สร้างโทเค็น API โดยทำดังนี้
- เปิดการตั้งค่า Kaggle
- เลือก "สร้างโทเค็นใหม่"
- ดาวน์โหลดไฟล์
kaggle.json
แล้ว ซึ่งมีข้อมูลเข้าสู่ระบบ Kaggle ของคุณ
เรียกใช้เซลล์ต่อไปนี้และป้อนข้อมูลเข้าสู่ระบบ Kaggle เมื่อระบบขอ
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
อีกวิธีหนึ่งคือการตั้งค่า KAGGLE_USERNAME และ KAGGLE_KEY ในสภาพแวดล้อมของคุณ หาก kagglehub.login() ใช้งานไม่ได้
การติดตั้ง
ติดตั้ง Keras และ KerasNLP ด้วยโมเดล Gemma
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
ตั้งค่าแบ็กเอนด์ Keras JAX
นำเข้า JAX และเรียกใช้การตรวจสอบความถูกต้องบน TPU Kaggle นำเสนออุปกรณ์ TPUv3-8 ซึ่งมี TPU 8 แกนพร้อมหน่วยความจำ 16 GB ในแต่ละชิ้น
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
โหลดโมเดล
import keras
import keras_nlp
หมายเหตุเกี่ยวกับการฝึกความแม่นยำแบบผสมใน GPU ของ NVIDIA
เมื่อฝึกใน GPU ของ NVIDIA คุณสามารถใช้ความแม่นยำแบบผสม (keras.mixed_precision.set_global_policy('mixed_bfloat16')
) เพื่อเร่งการฝึกโดยให้ส่งผลต่อคุณภาพการฝึกน้อยที่สุด ในกรณีส่วนใหญ่ ขอแนะนำให้เปิดใช้ความแม่นยำแบบผสมเพราะจะช่วยประหยัดทั้งหน่วยความจำและเวลา อย่างไรก็ตาม โปรดทราบว่าเมื่อมีไฟล์ขนาดเล็ก อาจทำให้มีการใช้หน่วยความจำสูงเกินจริงได้ถึง 1.5 เท่า (น้ำหนักจะโหลด 2 ครั้ง โดยใช้ความแม่นยำครึ่งหนึ่งและแม่นยำสูงสุด)
สำหรับการอนุมาน ความแม่นยำครึ่งหนึ่ง (keras.config.set_floatx("bfloat16")
) จะทำงานและประหยัดหน่วยความจำได้ แต่ไม่สามารถใช้ความแม่นยําแบบผสม
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
หากต้องการโหลดโมเดลด้วยน้ำหนักและ Tensor ที่กระจายทั่วทั้ง TPU ให้สร้าง DeviceMesh
ใหม่ก่อน DeviceMesh
แสดงถึงคอลเล็กชันของอุปกรณ์ฮาร์ดแวร์ที่กำหนดค่าไว้สำหรับการประมวลผลแบบกระจาย และเปิดตัวใน Keras 3 โดยเป็นส่วนหนึ่งของ API การจัดจำหน่ายแบบรวม
API การกระจายช่วยให้มีข้อมูลและโมเดลทำงานพร้อมกัน ทำให้สามารถปรับขนาดโมเดลการเรียนรู้เชิงลึกได้อย่างมีประสิทธิภาพใน Accelerator และโฮสต์หลายรายการ ซึ่งจะใช้ประโยชน์จากเฟรมเวิร์กพื้นฐาน (เช่น JAX) เพื่อกระจายโปรแกรมและ Tensor ตามคำแนะนำการชาร์ดดิ้งผ่านกระบวนการที่เรียกว่าการขยายโปรแกรมเดียวหรือการขยายข้อมูลหลายรายการ (SPMD) ดูรายละเอียดเพิ่มเติมในคู่มือ API การจัดจำหน่าย Keras 3 ใหม่
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
LayoutMap
จาก API การกระจายจะระบุวิธีชาร์ดหรือจำลองน้ำหนักและ Tensor โดยใช้คีย์สตริง เช่น token_embedding/embeddings
ด้านล่าง ซึ่งระบบจะถือว่าเหมือนนิพจน์ทั่วไปในการจับคู่เส้นทาง Tensor Tensor ที่ตรงกันจะมีการชาร์ดที่มีขนาดโมเดล (TPU 8 ตัว) รายการอื่นๆ จะถูกคัดลอกทั้งหมด
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
ModelParallel
ให้คุณชาร์ดน้ำหนักโมเดลหรือ Tensor การเปิดใช้งานในทุกแพลตฟอร์มใน DeviceMesh
ได้ ในกรณีนี้ น้ำหนักโมเดล Gemma 7B บางส่วนจะมีการชาร์ดในชิป TPU 8 ชิปตาม layout_map
ที่ระบุข้างต้น คราวนี้ให้โหลดโมเดลแบบกระจาย
model_parallel = keras.distribution.ModelParallel(
device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
ในขั้นตอนนี้ให้ตรวจสอบว่าโมเดลมีการแบ่งพาร์ติชันอย่างถูกต้อง ลองดู decoder_block_1
เป็นตัวอย่าง
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
การอนุมานก่อนการปรับแต่ง
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
โมเดลนี้จะสร้างรายการภาพยนตร์ตลกยอดเยี่ยมจากยุค 90 ให้รับชม ในตอนนี้ เราได้ปรับแต่งโมเดล Gemma เพื่อเปลี่ยนรูปแบบเอาต์พุต
ปรับแต่งด้วย IMDB
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
ทำการปรับแต่งโดยใช้การปรับอันดับต่ำ (LoRA) LoRA เป็นเทคนิคการปรับแต่งซึ่งช่วยลดจำนวนพารามิเตอร์ที่ฝึกได้สำหรับงานดาวน์สตรีมลงได้อย่างมากด้วยการตรึงน้ำหนักเต็มของโมเดลและใส่น้ำหนักใหม่ที่ฝึกได้ลงในโมเดลจำนวนที่น้อยลง พูดง่ายๆ ก็คือ LoRA จะเปลี่ยนพารามิเตอร์เมทริกซ์แบบเต็มน้ำหนักที่ใหญ่กว่าโดยใช้ AxB เมทริกซ์ระดับต่ำ 2 เมทริกซ์ที่เล็กกว่าเพื่อฝึก และเทคนิคนี้ทำให้การฝึกเร็วขึ้นและประหยัดหน่วยความจำได้มากยิ่งขึ้น
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
โปรดทราบว่าการเปิดใช้ LoRA จะลดจำนวนพารามิเตอร์ที่ฝึกได้ลงได้อย่างมาก จาก 7 พันล้านรายการเหลือเพียง 11 ล้านรายการ
การอนุมานหลังการปรับแต่ง
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
หลังจากปรับแต่งอย่างละเอียด โมเดลก็ได้เรียนรู้สไตล์การรีวิวภาพยนตร์ และสร้างผลลัพธ์ในสไตล์ดังกล่าวในบริบทของภาพยนตร์ตลกยุค 90
ขั้นตอนถัดไป
ในบทแนะนำนี้ คุณได้เรียนรู้วิธีใช้แบ็กเอนด์ KerasNLP JAX เพื่อปรับแต่งโมเดล Gemma บนชุดข้อมูล IMDb ในลักษณะที่มีการกระจายบน TPU ที่มีประสิทธิภาพ ต่อไปนี้เป็นคำแนะนำเล็กๆ น้อยๆ เกี่ยวกับสิ่งอื่นที่ควรเรียนรู้
- ดูวิธีเริ่มต้นใช้งาน Keras Gemma
- ดูวิธีปรับแต่งโมเดล Gemma ใน GPU