Điều chỉnh phân phối với Gemma bằng Keras

Xem trên ai.google.dev Chạy trong Google Colab Chạy ở Kaggle Mở trong Vertex AI Xem nguồn trên GitHub

Tổng quan

Gemma là một dòng mô hình mở, gọn nhẹ và tiên tiến được xây dựng từ nghiên cứu và công nghệ dùng để tạo các mô hình Gemini của Google. Gemma có thể được tinh chỉnh thêm để phù hợp với nhu cầu cụ thể. Tuy nhiên, các Mô hình ngôn ngữ lớn như Gemma có thể có kích thước rất lớn và một số Mô hình có thể không vừa với máy tăng tốc hát để tinh chỉnh. Trong trường hợp này, có hai phương pháp chung để tinh chỉnh chúng:

  1. Tinh chỉnh hiệu quả thông số (PEFT), tìm cách giảm kích thước mô hình hiệu quả bằng cách hy sinh một chút độ trung thực. LoRA thuộc danh mục này và hướng dẫn Tinh chỉnh mô hình Gemma trong Keras bằng LoRA minh hoạ cách tinh chỉnh mô hình Gemma 2B gemma_2b_en bằng LoRA bằng KerasNLP trên một GPU.
  2. Tinh chỉnh tham số đầy đủ bằng tính năng song song mô hình. Tính năng song song của mô hình giúp phân phối trọng số của một mô hình trên nhiều thiết bị và cho phép điều chỉnh theo tỷ lệ theo chiều ngang. Bạn có thể tìm hiểu thêm về chương trình đào tạo phân tán trong hướng dẫn về Kernel này.

Hướng dẫn này sẽ hướng dẫn bạn cách sử dụng Keras với phần phụ trợ JAX để tinh chỉnh mô hình Gemma 7B bằng LoRA và chương trình huấn luyện phân tán hiệu ứng biến đổi mô hình trên Đơn vị xử lý Tensor (TPU) của Google. Lưu ý là bạn có thể tắt LoRA trong hướng dẫn này để điều chỉnh thông số đầy đủ chậm hơn nhưng chính xác hơn.

Sử dụng trình tăng tốc

Về mặt kỹ thuật, bạn có thể sử dụng TPU hoặc GPU cho hướng dẫn này.

Ghi chú trên môi trường TPU

Google có 3 sản phẩm cung cấp TPU:

  • Colab cung cấp TPU phiên bản 2 miễn phí và hướng dẫn này là đủ.
  • Kaggle cung cấp TPU phiên bản 3 miễn phí và chúng cũng phù hợp với hướng dẫn này.
  • Cloud TPU cung cấp TPU phiên bản 3 và các thế hệ mới hơn. Một cách để thiết lập là:
    1. Tạo một máy ảo TPU mới
    2. Thiết lập tính năng chuyển tiếp cổng SSH cho cổng máy chủ Jupyter mà bạn dự định sử dụng
    3. Cài đặt Jupyter và khởi động ứng dụng này trên máy ảo TPU, sau đó kết nối với Colab thông qua tính năng "Kết nối với một môi trường thời gian chạy cục bộ"

Lưu ý về cách thiết lập nhiều GPU

Mặc dù hướng dẫn này tập trung vào trường hợp sử dụng TPU, nhưng bạn có thể dễ dàng điều chỉnh trường hợp này cho phù hợp với nhu cầu của riêng mình nếu bạn có máy sử dụng nhiều GPU.

Nếu muốn làm việc qua Colab, bạn cũng có thể cung cấp trực tiếp một máy ảo nhiều GPU cho Colab bằng cách "Kết nối với một máy ảo GCE tuỳ chỉnh" trong trình đơn Colab Connect.

Chúng ta sẽ tập trung vào việc sử dụng TPU miễn phí của Kaggle.

Trước khi bắt đầu

Thông tin đăng nhập Kaggle

Các mô hình Gemma do Kaggle lưu trữ. Để sử dụng Gemma, hãy yêu cầu quyền truy cập trên Kaggle:

  • Đăng nhập hoặc đăng ký tại kaggle.com
  • Mở thẻ mô hình Gemma rồi chọn "Yêu cầu quyền truy cập"
  • Hoàn tất biểu mẫu đồng ý rồi chấp nhận các điều khoản và điều kiện

Sau đó, để sử dụng Kaggle API, hãy tạo một mã thông báo API:

  • Mở phần cài đặt Kaggle
  • Chọn "Tạo mã thông báo mới"
  • Đang tải một tệp kaggle.json xuống. Lớp này chứa thông tin đăng nhập Kaggle của bạn

Chạy ô sau đây và nhập thông tin đăng nhập Kaggle của bạn khi được yêu cầu.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Một cách khác là đặt KAGGLE_USERNAME và KAGGLE_KEY trong môi trường của bạn nếu kagglehub.login() không phù hợp với bạn.

Cài đặt

Cài đặt Keras và KerasNLP bằng mô hình Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Thiết lập phần phụ trợ Keras JAX

Nhập JAX và chạy quy trình kiểm tra tính hợp lý trên TPU. Kaggle cung cấp các thiết bị TPUv3-8 có 8 nhân TPU với bộ nhớ 16GB mỗi nhân.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Mô hình tải

import keras
import keras_nlp

Lưu ý về chương trình đào tạo độ chính xác hỗn hợp trên GPU NVIDIA

Khi huấn luyện trên GPU NVIDIA, độ chính xác kết hợp (keras.mixed_precision.set_global_policy('mixed_bfloat16')) có thể được dùng để tăng tốc độ tập luyện mà không gây ảnh hưởng lớn đến chất lượng đào tạo. Trong hầu hết các trường hợp, bạn nên bật chế độ độ chính xác hỗn hợp vì tính năng này sẽ giúp tiết kiệm cả bộ nhớ và thời gian. Tuy nhiên, hãy lưu ý rằng ở kích thước lô nhỏ, nó có thể làm tăng mức sử dụng bộ nhớ lên 1,5 lần (trọng số sẽ được tải hai lần, với độ bán chính xác và độ chính xác đầy đủ).

Để suy luận, độ bán chính xác (keras.config.set_floatx("bfloat16")) sẽ hoạt động và tiết kiệm bộ nhớ, trong khi không thể áp dụng độ chính xác hỗn hợp.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Để tải mô hình bằng các trọng số và tensor được phân phối trên các TPU, trước tiên, hãy tạo một DeviceMesh mới. DeviceMesh đại diện cho một tập hợp các thiết bị phần cứng được định cấu hình để tính toán phân phối và được ra mắt trong Keras 3 dưới dạng một phần của API phân phối hợp nhất.

API phân phối hỗ trợ tính năng tải song song dữ liệu và mô hình, giúp mở rộng quy mô một cách hiệu quả các mô hình học sâu trên nhiều trình tăng tốc và máy chủ lưu trữ. Công cụ này tận dụng khung cơ bản (ví dụ: JAX) để phân phối chương trình và các tensor theo lệnh phân đoạn thông qua một quy trình được gọi là mở rộng chương trình đơn, nhiều dữ liệu (SPMD). Hãy xem thêm thông tin chi tiết trong hướng dẫn mới về API phân phối Kers 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap từ API phân phối chỉ định cách phân đoạn hoặc sao chép trọng số và tensor bằng cách sử dụng các khoá chuỗi (ví dụ: token_embedding/embeddings bên dưới) (được xử lý như biểu thức chính quy để khớp với các đường dẫn tensor). Các tensor phù hợp được phân đoạn với kích thước mô hình (8 TPU); những thiết bị khác sẽ được sao chép hoàn toàn.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel cho phép bạn phân đoạn trọng số mô hình hoặc tensor kích hoạt trên mọi thiết bị trên DeviceMesh. Trong trường hợp này, một số trọng số của mô hình Gemma 7B được phân đoạn trên 8 chip TPU theo layout_map được xác định ở trên. Bây giờ, hãy tải mô hình theo cách phân phối.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Bây giờ, hãy xác minh rằng mô hình đã được phân vùng chính xác. Hãy lấy decoder_block_1 làm ví dụ.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Suy luận trước khi tinh chỉnh

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Người mẫu tạo ra một danh sách các bộ phim hài đặc sắc từ thập niên 90 để xem. Bây giờ, chúng ta sẽ tinh chỉnh mô hình Gemma để thay đổi kiểu đầu ra.

Tinh chỉnh bằng IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Tinh chỉnh bằng cách sử dụng tính năng Điều chỉnh thứ hạng thấp (LoRA). LoRA là một kỹ thuật tinh chỉnh giúp giảm đáng kể số lượng tham số có thể huấn luyện cho các tác vụ hạ nguồn bằng cách đóng băng toàn bộ trọng số của mô hình và chèn một số lượng nhỏ hơn trọng số mới có thể huấn luyện vào mô hình. Về cơ bản, LoRA tái tham số các ma trận có trọng số đầy đủ lớn hơn bằng 2 ma trận hạng thấp hơn AxB để huấn luyện, và kỹ thuật này giúp quá trình huấn luyện nhanh hơn và tiết kiệm bộ nhớ hơn nhiều.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Lưu ý rằng việc bật LoRA sẽ làm giảm đáng kể số lượng thông số có thể huấn luyện, từ 7 tỷ xuống chỉ còn 11 triệu.

Suy luận sau khi tinh chỉnh

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Sau khi tinh chỉnh, mô hình này đã học được phong cách đánh giá phim và hiện đang tạo ra kết quả theo phong cách đó trong bối cảnh phim hài những năm 90.

Các bước tiếp theo

Trong hướng dẫn này, bạn đã tìm hiểu cách sử dụng phần phụ trợ KerasNLP JAX để tinh chỉnh mô hình Gemma trên tập dữ liệu IMDb theo cách phân tán trên các TPU mạnh mẽ. Sau đây là một số đề xuất về những điều khác bạn nên tìm hiểu: