Điều chỉnh phân phối với Gemma bằng Keras

Xem trên ai.google.dev Chạy trong Google Colab Chạy ở Kaggle Mở trong Vertex AI Xem nguồn trên GitHub

Tổng quan

Gemma là một bộ mô hình mở, hiện đại và gọn nhẹ được xây dựng từ nghiên cứu và công nghệ dùng để tạo mô hình Google Gemini. Gemma có thể được tinh chỉnh thêm để phù hợp với nhu cầu cụ thể. Tuy nhiên, Mô hình ngôn ngữ lớn (chẳng hạn như Gemma) có thể có kích thước rất lớn và một số mô hình có thể không phù hợp với một trình tăng tốc đơn để tinh chỉnh. Trong trường hợp này, có hai phương pháp chung để tinh chỉnh chúng:

  1. Điều chỉnh tinh vi hiệu quả theo tham số (PEFT), nhằm thu nhỏ kích thước mô hình hiệu quả bằng cách hy sinh một số độ trung thực. LoRA thuộc danh mục này và hướng dẫn Chỉnh sửa mô hình Gemma trong Keras bằng LoRA minh hoạ cách tinh chỉnh mô hình Gemma 2B gemma_2b_en bằng LoRA bằng KerasNLP trên một GPU.
  2. Tinh chỉnh tham số đầy đủ bằng tính năng song song mô hình. Tính năng song song của mô hình phân phối trọng số của một mô hình trên nhiều thiết bị và cho phép mở rộng theo chiều ngang. Bạn có thể tìm hiểu thêm về tính năng huấn luyện phân tán trong hướng dẫn về Keras này.

Hướng dẫn này sẽ hướng dẫn bạn cách sử dụng Keras với phần phụ trợ JAX để tinh chỉnh mô hình Gemma 7B bằng LoRA và đào tạo phân tán song song mô hình trên Bộ xử lý tensor (TPU) của Google. Xin lưu ý rằng bạn có thể tắt LoRA trong hướng dẫn này để điều chỉnh thông số đầy đủ chậm hơn nhưng chính xác hơn.

Sử dụng trình tăng tốc

Về mặt kỹ thuật, bạn có thể sử dụng TPU hoặc GPU cho hướng dẫn này.

Lưu ý về môi trường TPU

Google có 3 sản phẩm cung cấp TPU:

  • Colab cung cấp TPU phiên bản 2 miễn phí, đủ cho hướng dẫn này.
  • Kaggle cung cấp TPU v3 miễn phí và cũng phù hợp với hướng dẫn này.
  • Cloud TPU cung cấp TPU phiên bản 3 trở lên. Một cách để thiết lập tính năng này là:
    1. Tạo máy ảo TPU mới
    2. Thiết lập tính năng chuyển tiếp cổng SSH cho cổng máy chủ Jupyter mà bạn dự định sử dụng
    3. Cài đặt và khởi động Jupyter trên máy ảo TPU, sau đó kết nối với Colab thông qua tính năng "Kết nối với môi trường thời gian chạy cục bộ"

Lưu ý về cách thiết lập nhiều GPU

Mặc dù hướng dẫn này tập trung vào trường hợp sử dụng TPU, nhưng bạn có thể dễ dàng điều chỉnh hướng dẫn này cho nhu cầu của riêng mình nếu có máy nhiều GPU.

Nếu muốn làm việc thông qua Colab, bạn cũng có thể trực tiếp cấp phép máy ảo nhiều GPU cho Colab thông qua mục "Kết nối với một máy ảo GCE tuỳ chỉnh" trong trình đơn Colab Connect.

Chúng ta sẽ tập trung vào việc sử dụng TPU miễn phí của Kaggle tại đây.

Trước khi bắt đầu

Thông tin xác thực Kaggle

Các mô hình Gemma do Kaggle lưu trữ. Để sử dụng Gemma, hãy yêu cầu quyền truy cập trên Kaggle:

  • Đăng nhập hoặc đăng ký tại kaggle.com
  • Mở thẻ mô hình Gemma rồi chọn "Yêu cầu quyền truy cập"
  • Hoàn tất biểu mẫu đồng ý rồi chấp nhận các điều khoản và điều kiện

Sau đó, để sử dụng Kaggle API, hãy tạo một mã thông báo API:

  • Mở chế độ cài đặt Kaggle
  • Chọn "Tạo mã thông báo mới"
  • Tệp kaggle.json được tải xuống. Tệp này chứa thông tin xác thực của bạn trên Kaggle

Chạy ô sau và nhập thông tin xác thực Kaggle của bạn khi được nhắc.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Một cách khác là đặt KAGGLE_USERNAME và KAGGLE_KEY trong môi trường của bạn nếu kagglehub.login() không hoạt động.

Cài đặt

Cài đặt Keras và KerasNLP bằng mô hình Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Thiết lập phần phụ trợ Keras JAX

Nhập JAX và chạy quy trình kiểm tra tính hợp lý trên TPU. Kaggle cung cấp các thiết bị TPUv3-8 có 8 lõi TPU với mỗi lõi có 16 GB bộ nhớ.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Mô hình tải

import keras
import keras_nlp

Ghi chú về việc huấn luyện độ chính xác kết hợp trên GPU NVIDIA

Khi huấn luyện trên GPU NVIDIA, bạn có thể sử dụng độ chính xác kết hợp (keras.mixed_precision.set_global_policy('mixed_bfloat16')) để tăng tốc độ huấn luyện mà không ảnh hưởng nhiều đến chất lượng huấn luyện. Trong hầu hết các trường hợp, bạn nên bật chế độ độ chính xác hỗn hợp vì tính năng này sẽ giúp tiết kiệm cả bộ nhớ và thời gian. Tuy nhiên, hãy lưu ý rằng ở kích thước lô nhỏ, việc này có thể làm tăng mức sử dụng bộ nhớ lên 1,5 lần (trọng số sẽ được tải hai lần, ở độ bán chính xác và độ chính xác đầy đủ).

Đối với suy luận, độ bán chính xác (keras.config.set_floatx("bfloat16")) sẽ hoạt động và tiết kiệm bộ nhớ trong khi độ chính xác kết hợp không áp dụng được.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Để tải mô hình bằng các trọng số và tensor được phân phối trên các TPU, trước tiên, hãy tạo một DeviceMesh mới. DeviceMesh đại diện cho một tập hợp các thiết bị phần cứng được định cấu hình để tính toán phân tán và được giới thiệu trong Keras 3 như một phần của API phân phối hợp nhất.

API phân phối cho phép dữ liệu và mô hình song song, cho phép mở rộng quy mô hiệu quả các mô hình học sâu trên nhiều trình tăng tốc và máy chủ. Công cụ này tận dụng khung cơ bản (ví dụ: JAX) để phân phối chương trình và các tensor theo lệnh phân đoạn thông qua một quy trình được gọi là mở rộng chương trình đơn, nhiều dữ liệu (SPMD). Hãy xem thêm thông tin chi tiết trong hướng dẫn mới về API phân phối Kers 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap từ API phân phối chỉ định cách phân đoạn hoặc sao chép trọng số và tensor bằng cách sử dụng các khoá chuỗi (ví dụ: token_embedding/embeddings bên dưới) (được xử lý như biểu thức chính quy để khớp với các đường dẫn tensor). Các tensor được so khớp được phân đoạn theo các phương diện mô hình (8 TPU); các tensor khác sẽ được sao chép đầy đủ.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel cho phép bạn phân đoạn trọng số mô hình hoặc tensor kích hoạt trên tất cả các thiết bị trên DeviceMesh. Trong trường hợp này, một số trọng số mô hình Gemma 7B được phân đoạn trên 8 khối TPU theo layout_map được xác định ở trên. Bây giờ, hãy tải mô hình theo cách phân phối.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Bây giờ, hãy xác minh rằng mô hình đã được phân vùng chính xác. Hãy lấy decoder_block_1 làm ví dụ.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Suy luận trước khi tinh chỉnh

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Người mẫu tạo ra một danh sách các bộ phim hài đặc sắc từ thập niên 90 để xem. Bây giờ, chúng ta sẽ tinh chỉnh mô hình Gemma để thay đổi kiểu đầu ra.

Tinh chỉnh bằng IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Tinh chỉnh bằng cách sử dụng tính năng Điều chỉnh thứ hạng thấp (LoRA). LoRA là một kỹ thuật tinh chỉnh giúp giảm đáng kể số lượng tham số có thể huấn luyện cho các tác vụ ở hạ nguồn bằng cách đóng băng toàn bộ trọng số của mô hình và chèn một số lượng nhỏ trọng số mới có thể huấn luyện vào mô hình. Về cơ bản, LoRA định lại tham số cho các ma trận trọng số đầy đủ lớn hơn bằng 2 ma trận AxB có thứ hạng thấp hơn để huấn luyện và kỹ thuật này giúp quá trình huấn luyện nhanh hơn và hiệu quả hơn về bộ nhớ.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Lưu ý rằng việc bật LoRA sẽ làm giảm đáng kể số lượng thông số có thể huấn luyện, từ 7 tỷ xuống chỉ còn 11 triệu.

Suy luận sau khi tinh chỉnh

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Sau khi tinh chỉnh, mô hình đã học được phong cách của bài đánh giá phim và hiện đang tạo ra kết quả theo phong cách đó trong bối cảnh của các bộ phim hài thập niên 90.

Bước tiếp theo

Trong hướng dẫn này, bạn đã tìm hiểu cách sử dụng phần phụ trợ KerasNLP JAX để tinh chỉnh mô hình Gemma trên tập dữ liệu IMDb theo cách phân tán trên các TPU mạnh mẽ. Dưới đây là một số đề xuất về những kiến thức khác mà bạn nên tìm hiểu: