Điều chỉnh phân phối với Gemma bằng Keras

Xem trên ai.google.dev Chạy trong Kaggle Mở trong Vertex AI Xem nguồn trên GitHub

Tổng quan

Gemma là một dòng mô hình mở, hiện đại và gọn nhẹ, được xây dựng từ quá trình nghiên cứu và công nghệ dùng để tạo ra các mô hình Google Gemini. Gemma có thể được tinh chỉnh thêm cho phù hợp với nhu cầu cụ thể. Tuy nhiên, các Mô hình ngôn ngữ lớn, chẳng hạn như Gemma, có thể có kích thước rất lớn và một số mô hình có thể không vừa với bộ tăng tốc hát để tinh chỉnh. Trong trường hợp này, có hai phương pháp chung để tinh chỉnh chúng:

  1. Hiệu quả tinh chỉnh thông số (PEFT), tìm cách giảm kích thước mô hình hiệu quả bằng cách giảm đi một chút độ trung thực. LoRA thuộc danh mục này và hướng dẫn Tinh chỉnh các mô hình Gemma trong Keras bằng LoRA minh hoạ cách tinh chỉnh mô hình Gemma 2B gemma_2b_en với LoRA bằng KerasNLP trên một GPU.
  2. Tinh chỉnh tham số đầy đủ bằng tính năng song song của mô hình. Tính năng tải song song của mô hình sẽ phân phối trọng số của một mô hình trên nhiều thiết bị và cho phép điều chỉnh theo tỷ lệ theo chiều ngang. Bạn có thể tìm hiểu thêm về chương trình đào tạo phân tán trong hướng dẫn này của Kers.

Hướng dẫn này sẽ hướng dẫn bạn cách sử dụng Keras cùng với phần phụ trợ JAX để tinh chỉnh mô hình Gemma 7B với LoRA và huấn luyện phân phối mô hình-parallism trên Bộ xử lý Tensor (TPU) của Google. Xin lưu ý rằng bạn có thể tắt LoRA trong hướng dẫn này để điều chỉnh toàn bộ thông số chậm hơn nhưng chính xác hơn.

Sử dụng trình tăng tốc

Về mặt kỹ thuật, bạn có thể sử dụng TPU hoặc GPU cho hướng dẫn này.

Ghi chú trên môi trường TPU

Google có 3 sản phẩm cung cấp TPU:

  • Colab cung cấp TPU v2 nhưng chưa đủ cho hướng dẫn này.
  • Kaggle cung cấp TPU phiên bản 3 miễn phí và bạn có thể sử dụng TPU phiên bản này cho hướng dẫn này.
  • Cloud TPU cung cấp TPU phiên bản 3 và các thế hệ mới hơn. Bạn có thể thiết lập theo cách sau:
    1. Tạo một máy ảo TPU mới
    2. Thiết lập tính năng chuyển tiếp cổng SSH cho cổng máy chủ Jupyter mà bạn dự định
    3. Cài đặt Jupyter rồi khởi động trên máy ảo TPU, sau đó kết nối với Colab thông qua tuỳ chọn "Kết nối với một môi trường thời gian chạy cục bộ"

Lưu ý khi thiết lập nhiều GPU

Mặc dù hướng dẫn này tập trung vào trường hợp sử dụng TPU, nhưng bạn có thể dễ dàng điều chỉnh cho phù hợp với nhu cầu của mình nếu có máy nhiều GPU.

Nếu muốn làm việc thông qua Colab, bạn cũng có thể cấp trực tiếp một máy ảo nhiều GPU cho Colab thông qua mục "Kết nối với một máy ảo GCE tuỳ chỉnh" trong trình đơn của Colab Connect.

Chúng ta sẽ tập trung vào việc sử dụng TPU miễn phí của Kaggle tại đây.

Trước khi bắt đầu

Thông tin đăng nhập Kaggle

Các mô hình Gemma do Kaggle lưu trữ. Để sử dụng Gemma, hãy yêu cầu quyền truy cập trên Kaggle:

  • Đăng nhập hoặc đăng ký tại kaggle.com
  • Mở thẻ mô hình Gemma rồi chọn "Yêu cầu quyền truy cập"
  • Hoàn thành biểu mẫu đồng ý, cũng như chấp nhận các điều khoản và điều kiện

Sau đó, để sử dụng API Kaggle, hãy tạo một mã thông báo API:

  • Mở phần cài đặt Kaggle
  • Chọn "Tạo mã thông báo mới"
  • Đã tải một tệp kaggle.json xuống. Tệp này chứa thông tin đăng nhập Kaggle của bạn

Chạy ô sau và nhập thông tin xác thực Kaggle của bạn khi được yêu cầu.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Một cách khác là đặt KAGGLE_USERNAME và KAGGLE_KEY trong môi trường của bạn nếu kagglehub.login() không hoạt động cho bạn.

Cài đặt

Cài đặt Keras và KerasNLP bằng mô hình Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Thiết lập phần phụ trợ Keras JAX

Nhập JAX và chạy quy trình kiểm tra tình trạng hợp lý trên TPU. Kaggle cung cấp các thiết bị TPUv3-8 có 8 lõi TPU với bộ nhớ 16GB mỗi lõi.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Tải mô hình

import keras
import keras_nlp

Lưu ý về việc huấn luyện độ chính xác hỗn hợp trên GPU NVIDIA

Khi huấn luyện trên GPU NVIDIA, bạn có thể dùng độ chính xác hỗn hợp (keras.mixed_precision.set_global_policy('mixed_bfloat16')) để tăng tốc độ huấn luyện mà không gây ảnh hưởng lớn đến chất lượng huấn luyện. Trong hầu hết trường hợp, bạn nên bật độ chính xác hỗn hợp để tiết kiệm cả bộ nhớ và thời gian. Tuy nhiên, hãy lưu ý rằng ở các kích thước lô nhỏ, nó có thể làm tăng mức sử dụng bộ nhớ lên 1,5 lần (trọng số sẽ được tải hai lần, với độ bán chính xác và độ chính xác đầy đủ).

Để dự đoán, độ bán chính xác (keras.config.set_floatx("bfloat16")) sẽ hoạt động và tiết kiệm bộ nhớ, nhưng không áp dụng độ chính xác hỗn hợp.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Để tải mô hình bằng trọng số và tensor được phân phối trên các TPU, trước tiên, hãy tạo một DeviceMesh mới. DeviceMesh đại diện cho tập hợp các thiết bị phần cứng được định cấu hình để tính toán được phân phối. API này đã ra mắt trong Keras 3 dưới dạng một phần của API phân phối hợp nhất.

API phân phối hỗ trợ tính năng tải song song dữ liệu và mô hình, giúp mở rộng quy mô một cách hiệu quả của các mô hình học sâu trên nhiều trình tăng tốc và máy chủ lưu trữ. Lớp này tận dụng khung cơ bản (ví dụ: JAX) để phân phối chương trình và tensor theo lệnh phân đoạn thông qua một quy trình gọi là một chương trình, mở rộng nhiều dữ liệu (SPMD). Hãy xem thêm thông tin chi tiết trong Hướng dẫn về API phân phối Keras 3 mới.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap trong API phân phối chỉ định cách phân đoạn hoặc sao chép trọng số và tensor bằng cách sử dụng các khoá chuỗi (ví dụ: token_embedding/embeddings dưới đây) được coi là biểu thức chính quy để khớp với đường dẫn tensor. Tensor trùng khớp sẽ được phân đoạn theo kích thước mô hình (8 TPU); các tensor khác sẽ được sao chép hoàn toàn.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel cho phép bạn phân đoạn trọng số của mô hình hoặc tensor kích hoạt trên tất cả các thiết bị trên DeviceMesh. Trong trường hợp này, một số trọng số mô hình Gemma 7B được phân đoạn trên 8 chip TPU theo layout_map được xác định ở trên. Bây giờ, hãy tải mô hình theo cách được phân phối.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Bây giờ, hãy kiểm tra để đảm bảo mô hình đã được phân vùng đúng cách. Hãy lấy decoder_block_1 làm ví dụ.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Suy luận trước khi tinh chỉnh

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Mô hình này tạo ra một danh sách những bộ phim hài hay của những năm 90 để bạn xem. Bây giờ, chúng ta tinh chỉnh mô hình Gemma để thay đổi kiểu đầu ra.

Tinh chỉnh bằng IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Thực hiện tinh chỉnh bằng tính năng Điều chỉnh thứ hạng thấp (LoRA). LoRA là một kỹ thuật tinh chỉnh giúp giảm đáng kể số lượng tham số có thể huấn luyện cho các nhiệm vụ hạ nguồn bằng cách đóng băng toàn bộ trọng số của mô hình và chèn một số lượng nhỏ hơn các trọng số mới có thể huấn luyện vào mô hình. Về cơ bản, LoRA tái tham số hoá các ma trận trọng số đầy đủ lớn hơn bằng 2 ma trận bậc thấp nhỏ hơn AxB để huấn luyện và kỹ thuật này giúp việc huấn luyện nhanh hơn và tiết kiệm bộ nhớ hơn.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Lưu ý rằng việc bật LoRA sẽ làm giảm đáng kể số lượng tham số có thể huấn luyện, từ 7 tỷ xuống chỉ còn 11 triệu.

Suy luận sau khi tinh chỉnh

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Sau khi tinh chỉnh, mô hình này đã học được phong cách đánh giá phim và hiện đang tạo ra kết quả theo phong cách đó trong bối cảnh phim hài thập niên 90.

Bước tiếp theo

Trong hướng dẫn này, bạn đã tìm hiểu cách sử dụng phần phụ trợ KerasNLP JAX để tinh chỉnh mô hình Gemma trên tập dữ liệu IMDb theo cách được phân phối trên TPU mạnh mẽ. Sau đây là một vài đề xuất về những điều khác cần tìm hiểu: