ai.google.dev で表示 | Google Colab で実行 | Kaggle で実行する | Vertex AI で開く | GitHub でソースを見る |
概要
Gemma は、Google Gemini モデルの作成に使用された研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。Gemma は特定のニーズに合わせてさらに微調整できます。しかし、Gemma などの大規模言語モデルはサイズが非常に大きく、一部はファインチューニング用の単一アクセラレータに収まらない可能性があります。この場合、ファインチューニングには一般に次の 2 つの方法があります。
- パラメータ エフィシエント ファインチューニング(PEFT): 忠実度をある程度犠牲にして、有効なモデルサイズを縮小します。LoRA はこのカテゴリに該当します。LoRA を使用して Keras で Gemma モデルをファインチューニングするチュートリアルでは、単一の GPU で KerasNLP を使用して LoRA で Gemma 2B モデル
gemma_2b_en
をファインチューニングする方法について説明します。 - モデルの並列処理による完全なパラメータ ファインチューニング。モデル並列処理では、1 つのモデルの重みが複数のデバイスに分散され、水平スケーリングが可能になります。分散トレーニングの詳細については、Keras ガイドをご覧ください。
このチュートリアルでは、JAX バックエンドを使用して Keras を使用し、Google の Tensor Processing Unit(TPU)で LoRA とモデル並列分散トレーニングを使用して Gemma 7B モデルをファインチューニングする方法について説明します。このチュートリアルでは、LoRA をオフにして、時間はかかりますが、より正確な全パラメータ チューニングを行うことができます。
アクセラレータの使用
技術的には、このチュートリアルでは TPU または GPU のいずれかを使用できます。
TPU 環境に関する注意事項
Google には、TPU を提供する 3 つのプロダクトがあります。
- Colab では TPU v2 が無料で提供されており、このチュートリアルでは十分です。
- Kaggle は TPU v3 を無料で提供しており、このチュートリアルでも機能します。
- Cloud TPU は TPU v3 以降の世代を提供します。設定方法の 1 つは次のとおりです。
マルチ GPU の設定に関する注意事項
このチュートリアルでは TPU のユースケースに焦点を当てていますが、マルチ GPU マシンをお持ちの場合は、独自のニーズに合わせて簡単に調整できます。
Colab で作業する場合は、Colab Connect メニューの [カスタム GCE VM に接続します] から、Colab 用にマルチ GPU VM を直接プロビジョニングすることもできます。
ここでは、Kaggle の無料 TPU の使用に焦点を当てます。
始める前に
Kaggle 認証情報
Gemma モデルは Kaggle によってホストされます。Gemma を使用するには、Kaggle でアクセスをリクエストします。
- kaggle.com にログインまたは登録する
- Gemma モデルカードを開き、[Request Access] を選択します。
- 同意フォームに入力し、利用規約に同意する
次に、Kaggle API を使用するには、API トークンを作成します。
- Kaggle の設定を開きます。
- [Create New Token] を選択します。
kaggle.json
ファイルがダウンロードされます。Kaggle の認証情報が含まれています
次のセルを実行し、Kaggle の認証情報の入力を求められたら入力します。
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
kagglehub.login() が機能しない場合は、環境に KAGGLE_USERNAME と KAGGLE_KEY を設定することもできます。
インストール
Gemma モデルを使用して Keras と KerasNLP をインストールします。
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
Keras JAX バックエンドを設定する
JAX をインポートし、TPU で健全性チェックを実行します。Kaggle は、それぞれ 16 GB のメモリを備えた 8 個の TPU コアを備えた TPUv3-8 デバイスを提供しています。
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
モデルを読み込む
import keras
import keras_nlp
NVIDIA GPU での混合精度トレーニングに関する注意事項
NVIDIA GPU でトレーニングする場合は、混合精度(keras.mixed_precision.set_global_policy('mixed_bfloat16')
)を使用して、トレーニングの品質に最小限の影響を与えてトレーニングを高速化できます。ほとんどの場合、メモリと時間の両方を節約するため、混合精度をオンにすることをおすすめします。ただし、バッチサイズが小さいと、メモリ使用量が 1.5 倍になる可能性があります(重みは半精度と完全精度の 2 回読み込まれます)。
推論では、混合精度は適用されませんが、半精度(keras.config.set_floatx("bfloat16")
)が機能し、メモリを節約できます。
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
TPU に分散された重みとテンソルを使用してモデルを読み込むには、まず新しい DeviceMesh
を作成します。DeviceMesh
は、分散コンピューティング用に構成されたハードウェア デバイスのコレクションを表します。これは、統合された分散 API の一部として Keras 3 で導入されました。
distribution API を使用すると、データとモデルの並列処理が可能になり、複数のアクセラレータとホストでディープ ラーニング モデルを効率的にスケーリングできます。基盤となるフレームワーク(JAX など)を活用して、単一プログラム、複数データ(SPMD)拡張と呼ばれる手順で、シャーディング ディレクティブに従ってプログラムとテンソルを分散します。詳しくは、新しい Keras 3 ディストリビューション API ガイドをご覧ください。
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
distribution API の LayoutMap
では、文字列キーを使用して、重みとテンソルをシャーディングまたは複製する方法(下の token_embedding/embeddings
など)を指定します。この文字列キーは、テンソルパスと一致するように正規表現として扱われます。一致したテンソルはモデルの次元(8 個の TPU)でシャーディングされ、その他は完全に複製されます。
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
ModelParallel
を使用すると、DeviceMesh
のすべての偏差にわたってモデルの重みまたは活性化テンソルをシャーディングできます。この場合、Gemma 7B モデルの重みの一部は、上記で定義した layout_map
に従って 8 つの TPU チップにシャーディングされます。次に、分散方式でモデルを読み込みます。
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
次に、モデルが正しくパーティショニングされていることを確認します。decoder_block_1
を例に説明します。
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
ファインチューニング前の推論
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
このモデルは、90 年代のおすすめのコメディー映画のリストを生成します。次に、Gemma モデルを微調整して出力スタイルを変更します。
IMDB でファインチューニングする
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
Low Rank Adaptation(LoRA)を使用してファインチューニングを行います。LoRA はファインチューニング手法の一つで、モデルのすべての重みを凍結し、少数の新しいトレーニング可能な重みをモデルに挿入することで、ダウンストリームのタスク用にトレーニング可能なパラメータの数を大幅に削減します。基本的に LoRA は、大きな全重み行列を 2 つの小さな低ランク行列 AxB で再パラメータ化してトレーニングします。この手法により、トレーニングが大幅に高速化され、メモリ効率が向上します。
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
LoRA を有効にすると、トレーニング可能なパラメータの数が 70 億からわずか 1, 100 万に大幅に削減されます。
ファインチューニング後の推論
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
ファインチューニングの後、モデルは映画レビューのスタイルを学習し、90 年代のコメディ映画のコンテキストでそのスタイルの出力を生成しています。
次のステップ
このチュートリアルでは、KerasNLP JAX バックエンドを使用して、強力な TPU で分散処理で IMDb データセットの Gemma モデルをファインチューニングする方法について学習しました。その他の学習については、以下を参考にしてください。
- Keras Gemma を使ってみる方法を学びます。
- GPU で Gemma モデルをファインチューニングする方法を学習する。