Distributed tuning with Gemma using Keras


Overview

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Google Gemini models. Gemma can be further finetuned to suit specific needs. But Large Language Models such as Gemma can be very large in size, and some of them may not fit on a single accelerator for finetuning. In this case there are two general approaches for finetuning them:

  1. Parameter Efficient Fine-Tuning (PEFT), which seeks to shrink the effective model size by sacrificing some fidelity. LoRA falls in this category, and the "Fine-tune Gemma models in Keras using LoRA" tutorial demonstrates how to finetune the Gemma 2B model gemma_2b_en with LoRA using KerasNLP on a single GPU.
  2. Full parameter finetuning with model parallelism. Model parallelism distributes a single model's weights across multiple devices and enables horizontal scaling. You can find out more about distributed training in this Keras guide.

This tutorial walks you through using Keras with a JAX backend to finetune the Gemma 7B model with LoRA and model-parallelism distributed training on Google's Tensor Processing Units (TPUs). Note that LoRA can be turned off in this tutorial for slower but more accurate full-parameter tuning.

Using accelerators

Technically you can use either a TPU or a GPU for this tutorial.

Notes on TPU environments

Google has 3 products that provide TPUs:

  • Colab provides TPU v2 for free, which is sufficient for this tutorial.
  • Kaggle offers TPU v3 for free, and these also work for this tutorial.
  • Cloud TPU offers TPU v3 and newer generations. One way to set it up is to:
    1. Create a new TPU VM
    2. Set up SSH port forwarding for your intended Jupyter server port
    3. Install Jupyter, start it on the TPU VM, then connect to Colab through "Connect to a local runtime"

Notes on multi-GPU setup

Although this tutorial focuses on the TPU use case, you can easily adapt it for your own needs if you have a multi-GPU machine.

If you prefer to work with Colab, it is also possible to provision a multi-GPU VM for Colab directly through "Connect to a custom GCE VM" in the Colab Connect menu.

Here we will focus on using the free TPU from Kaggle.

Before you begin

Kaggle credentials

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle.

Then, to use the Kaggle API, create an API token:

  • Open your Kaggle settings
  • Select "Create New Token"
  • A kaggle.json file is downloaded. It contains your Kaggle credentials

Run the following cell and enter your Kaggle credentials when asked.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()

Alternatively, if kagglehub.login() does not work for you, set KAGGLE_USERNAME and KAGGLE_KEY in your environment.
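
For example, a minimal sketch (the values below are placeholders; use the credentials from your own kaggle.json):

import os

# Set Kaggle credentials directly if kagglehub.login() is not an option.
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"  # placeholder
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"        # placeholder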

Installation

Install Keras and KerasNLP with the Gemma model.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Set up the Keras JAX backend

Import JAX and run a sanity check on the TPU. Kaggle offers TPUv3-8 devices, which have 8 TPU cores with 16 GB of memory each.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Load the model

import keras
import keras_nlp

Notes on mixed precision training on NVIDIA GPUs

When training on NVIDIA GPUs, mixed precision (keras.mixed_precision.set_global_policy('mixed_bfloat16')) can be used to speed up training with minimal effect on training quality. In most cases, it is recommended to turn on mixed precision as it saves both memory and time. However, be aware that at small batch sizes it can inflate memory usage by 1.5x (weights will be loaded twice, at half precision and full precision).

For inference, half precision (keras.config.set_floatx("bfloat16")) will work and save memory, while mixed precision is not applicable there.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
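
For the inference-only path mentioned above, a minimal sketch is to set the global floating point type before loading the model (keep it commented out if you are following the training path in this tutorial):

# Uncomment the line below to load and run the model in half precision (bfloat16)
# for inference only. Do not combine this with mixed precision training.
# keras.config.set_floatx("bfloat16")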

To load the model with the weights and tensors distributed across TPUs, first create a new DeviceMesh. DeviceMesh represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.

The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new Keras 3 distribution API guide.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap from the distribution API specifies how the weights and tensors should be sharded or replicated, using string keys (for example, token_embedding/embeddings below) that are treated as regexes to match tensor paths. Matched tensors are sharded along the model dimension (across the 8 TPUs); the rest are fully replicated.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel allows you to shard model weights or activation tensors across all devices on the DeviceMesh. In this case, some of the Gemma 7B model weights are sharded across 8 TPU chips according to the layout_map defined above. Now load the model in the distributed way.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Now verify that the model has been partitioned correctly. Let's take decoder_block_1 as an example.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Inference before finetuning

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

The model generates a list of great comedy movies from the 90s. Now we finetune the Gemma model to change the output style.

Finetune with IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Perform finetuning using Low Rank Adaptation (LoRA). LoRA is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the full weights of the model and inserting a small number of new trainable weights into the model. Basically, LoRA reparameterizes the larger full weight matrices by two smaller low-rank matrices A and B that are trained instead, and this technique makes training much faster and more memory-efficient.
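
As a rough conceptual sketch (illustrative only, not the KerasNLP implementation; the tiny sizes below are placeholders), the LoRA update of a frozen weight matrix W adds a trainable low-rank product A @ B:

import numpy as np

# Conceptual LoRA reparameterization: y = x @ (W + A @ B).
d_in, d_out, rank = 16, 32, 4           # tiny example sizes; rank=4 matches enable_lora(rank=4) below
W = np.random.randn(d_in, d_out)        # frozen pretrained weight (not updated during finetuning)
A = np.random.randn(d_in, rank) * 0.01  # small trainable matrix
B = np.zeros((rank, d_out))             # small trainable matrix, initialized to zero
x = np.random.randn(2, d_in)            # a batch of activations
y = x @ W + (x @ A) @ B                 # equivalent to x @ (W + A @ B), but only A and B are trained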

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), 
ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Note that enabling LoRA reduces the number of trainable parameters significantly, from 7 billion to only 11 million.
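
If you want to sanity-check that count yourself, a small sketch using standard Keras model attributes is to sum the sizes of the trainable weights (exact numbers depend on the preset and the LoRA rank):

import numpy as np

# Count trainable vs. total parameters after enable_lora().
trainable_params = sum(int(np.prod(w.shape)) for w in gemma_lm.trainable_weights)
total_params = sum(int(np.prod(w.shape)) for w in gemma_lm.weights)
print(f"Trainable: {trainable_params:,} / Total: {total_params:,}")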

Inference after finetuning

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

After finetuning, the model has learned the style of movie reviews and is now generating output in that style in the context of 90s comedy movies.
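
If you want to keep the finetuned weights, one simple option is the standard Keras 3 weight-saving API (a sketch; the filename is just an example, and Gemma-specific export paths are not covered here):

# Save the finetuned weights to disk for later reuse with load_weights().
gemma_lm.save_weights("gemma_7b_imdb.weights.h5")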

What's next

In this tutorial, you learned how to use KerasNLP with the JAX backend to finetune a Gemma model on the IMDb dataset in a distributed way on powerful TPUs. Here are a few suggestions for what else to learn: