Keras を使用した Gemma による分散チューニング

ai.google.dev で表示 Kaggle で実行 Vertex AI で開く GitHub でソースを表示

概要

Gemma は、Google Gemini モデルの作成に使用された研究とテクノロジーから構築された、軽量で最先端のオープンモデルのファミリーです。Gemma は特定のニーズに合わせてさらに微調整できます。ただし、Gemma などの大規模言語モデルはサイズが非常に大きくなることがあり、ファインチューニングには sing Accelerator に適さないものもあります。この場合、ファインチューニングには 2 つの一般的なアプローチがあります。

  1. パラメータ エフィシエント ファインチューニング(PEFT)は、忠実度を犠牲にして、有効なモデルサイズを縮小しようとします。LoRA はこのカテゴリに分類されます。LoRA を使用して Keras で Gemma モデルを微調整するのチュートリアルでは、単一 GPU で KerasNLP を使用し、LoRA で Gemma 2B モデル gemma_2b_en を微調整する方法を説明しています。
  2. モデル並列処理による完全なパラメータ ファインチューニング。モデル並列処理では、1 つのモデルの重みを複数のデバイスに分散し、水平方向のスケーリングを可能にします。分散トレーニングの詳細については、こちらの Keras ガイドをご覧ください。

このチュートリアルでは、Keras と JAX バックエンドを使用して、LoRA で Gemma 7B モデルを微調整し、Google の Tensor Processing Unit(TPU)でモデルパラリズム分散トレーニングを行う方法について説明します。このチュートリアルでは LoRA をオフにすると、パラメータを完全にチューニングする際に時間をかけてより正確に調整できます。

アクセラレータの使用

このチュートリアルでは、技術的には TPU または GPU を使用できます。

TPU 環境に関する注意事項

Google では、TPU を提供する次の 3 つのプロダクトを提供しています。

  • Colab で提供される TPU v2 では、このチュートリアルでは不十分です。
  • Kaggle は TPU v3 を無料で提供しており、このチュートリアルに適しています。
  • Cloud TPU は TPU v3 以降の世代を提供します。たとえば、次のような設定が可能です。
    1. 新しい TPU VM を作成します。
    2. 目的の Jupyter サーバーポートに SSH ポート転送を設定します。
    3. Jupyter をインストールして TPU VM で起動してから、「ローカル ランタイムに接続する」で Colab に接続します。

マルチ GPU 設定に関する注意事項

このチュートリアルでは TPU のユースケースに焦点を当てていますが、マルチ GPU マシンをお持ちの場合は、独自のニーズに合わせて簡単に調整できます。

Colab で作業する場合は、Colab の接続メニューの [カスタム GCE VM に接続] から直接、Colab 用のマルチ GPU VM をプロビジョニングすることもできます。

ここでは、Kaggle の無料 TPU の使用に焦点を当てます。

始める前に

Kaggle 認証情報

Gemma モデルは Kaggle でホストされています。Gemma を使用するには、Kaggle でアクセス権をリクエストしてください。

  • kaggle.com でログインまたは登録
  • Gemma モデルカードを開き、[Request Access] を選択します。
  • 同意フォームに記入し、利用規約に同意します

次に、Kaggle API を使用するために API トークンを作成します。

  • Kaggle の設定を開きます。
  • [Create New Token] を選択します。
  • kaggle.json ファイルがダウンロードされます。Kaggle 認証情報が含まれています。

次のセルを実行し、プロンプトが表示されたら Kaggle 認証情報を入力します。

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

kagglehub.login() が機能しない場合は、環境で KAGGLE_USERNAME と KAGGLE_KEY を設定する方法もあります。

インストール

Gemma モデルを使用して Keras と KerasNLP をインストールする。

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX バックエンドを設定する

JAX をインポートし、TPU でサニティ チェックを実行します。Kaggle は、それぞれ 16 GB のメモリを備えた 8 個の TPU コアを搭載した TPUv3-8 デバイスを提供しています。

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

モデルを読み込む

import keras
import keras_nlp

NVIDIA GPU での混合精度トレーニングに関する注意事項

NVIDIA GPU でトレーニングする場合は、混合精度(keras.mixed_precision.set_global_policy('mixed_bfloat16'))を使用すると、トレーニングの品質への影響を最小限に抑えながらトレーニングを高速化できます。ほとんどの場合、メモリと時間の両方を節約するため、混合精度を有効にすることをおすすめします。ただし、小さいバッチサイズではメモリ使用量が 1.5 倍に増える可能性があるので注意してください(重みは半精度とフル精度で 2 回読み込まれます)。

推論では半精度(keras.config.set_floatx("bfloat16"))が機能し、メモリを節約できますが、混合精度は適用できません。

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

重みとテンソルを TPU に分散した状態でモデルを読み込むには、まず新しい DeviceMesh を作成します。DeviceMesh は、分散計算用に構成されたハードウェア デバイスの集合であり、Unified Distribution API の一部として Keras 3 で導入されました。

ディストリビューション API を使用すると、データとモデルの並列処理が可能になり、複数のアクセラレータやホスト上でディープ ラーニング モデルを効率的にスケーリングできます。基盤となるフレームワーク(JAX など)を活用し、SPMD(Single Program, Multiple Data)展開と呼ばれる手順を通じて、シャーディング ディレクティブに従ってプログラムとテンソルを分散します。詳細については、新しい Keras 3 ディストリビューション API ガイドをご覧ください。

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

ディストリビューション API の LayoutMap は、テンソルパスと一致する正規表現のように扱われる文字列キー(たとえば、以下の token_embedding/embeddings)を使用して、重みとテンソルをシャーディングまたは複製する方法を指定します。一致したテンソルはモデルの次元(8 TPU)でシャーディングされ、その他は完全に複製されます。

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel を使用すると、DeviceMesh のすべてのデバイスにわたってモデルの重みまたは活性化テンソルをシャーディングできます。この場合、Gemma 7B モデルの重みの一部は、上で定義した layout_map に従って 8 個の TPU チップにシャーディングされます。次に、分散された方法でモデルを読み込みます。

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

モデルが正しくパーティショニングされていることを確認します。decoder_block_1 を例にします。

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

ファインチューニング前の推論

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

モデルは、90 年代の人気のコメディ映画のリストを生成します。次に、Gemma モデルを微調整して出力スタイルを変更します。

IMDB によるファインチューニング

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

低ランク適応(LoRA)を使用して微調整を行います。LoRA はファインチューニング手法であり、モデルのすべての重みを固定し、少数の新しいトレーニング可能な重みをモデルに挿入することで、下流のタスクのトレーニング可能なパラメータの数を大幅に削減します。基本的に LoRA は、大きなフル重み行列を 2 つの小さな低ランク行列 AxB で再パラメータ化してトレーニングします。この手法により、トレーニングが大幅に高速化され、メモリ効率が向上します。

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

LoRA を有効にすると、トレーニング可能なパラメータの数が 70 億からわずか 1, 100 万に大幅に削減されます。

ファインチューニング後の推論

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

ファインチューニング後、モデルは映画レビューのスタイルを学習し、今度は 90 年代のコメディ映画のコンテキストで、そのスタイルで出力を生成しています。

次のステップ

このチュートリアルでは、KerasNLP JAX バックエンドを使用して、強力な TPU 上で分散方式で IMDb データセットの Gemma モデルを微調整する方法を学習しました。その他の情報について、いくつかヒントをご紹介します。