Keras で LoRA を使用して Gemma モデルをファインチューニングする

ai.google.dev で表示 Google Colab で実行 Vertex AI で開く GitHub でソースを見る

概要

Gemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。

Gemma などの大規模言語モデル(LLM)は、さまざまな NLP タスクで効果的であることが実証されています。LLM はまず、自己教師あり方式で大規模なテキストのコーパスで事前トレーニングされます。事前トレーニングにより、LLM は単語間の統計的関係などの汎用的な知識を学習できます。LLM は、ドメイン固有のデータでファインチューニングして、ダウンストリームのタスク(感情分析など)を実行できます。

LLM は非常に大規模です(パラメータは数十億単位)。通常のファインチューニング データセットは事前トレーニング データセットよりも比較的はるかに小さいため、ほとんどのアプリケーションでは完全なファインチューニング(モデル内のすべてのパラメータを更新する)は必要ありません。

Low Rank Adaptation(LoRA)は、モデルの重みを凍結し、少数の新しい重みをモデルに挿入することで、ダウンストリーム タスクのトレーニング可能なパラメータの数を大幅に削減する微調整手法です。これにより、LoRA でのトレーニングが大幅に高速化され、メモリ効率が向上します。また、モデル出力の品質を維持しながら、モデルの重み(数百 MB)を小さくできます。

このチュートリアルでは、Databricks Dolly 15k データセットを使用して、KerasNLP を使用して Gemma 2B モデルで LoRA ファインチューニングを行う方法について説明します。このデータセットには、LLM のファインチューニング用に特別に設計された、人間が生成した高品質なプロンプトとレスポンス ペアが 15,000 個含まれています。

セットアップ

Gemma にアクセスする

このチュートリアルを完了するには、まず Gemma の設定にある設定手順を完了する必要があります。Gemma のセットアップ手順では、次の方法について説明しています。

  • kaggle.com で Gemma にアクセスしてください。
  • Gemma 2B モデルを実行するのに十分なリソースがある Colab ランタイムを選択します。
  • Kaggle のユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。

ランタイムの選択

このチュートリアルを完了するには、Gemma モデルを実行するのに十分なリソースを備えた Colab ランタイムが必要です。この場合は、T4 GPU を使用できます。

  1. Colab ウィンドウの右上にある ▾(その他の接続オプション)を選択します。
  2. [ランタイムのタイプを変更] を選択します。
  3. [Hardware accelerator] で [T4 GPU] を選択します。

API キーを設定する

Gemma を使用するには、Kaggle のユーザー名と Kaggle API キーを指定する必要があります。

Kaggle API キーを生成するには、Kaggle ユーザー プロファイルの [アカウント] タブに移動し、[新しいトークンを作成] を選択します。これにより、API 認証情報を含む kaggle.json ファイルのダウンロードがトリガーされます。

Colab で、左側のペインで [シークレット](🔑?)を選択し、Kaggle ユーザー名と Kaggle API キーを追加します。ユーザー名を KAGGLE_USERNAME という名前で、API キーを KAGGLE_KEY という名前で保存します。

環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

依存関係のインストール

Keras、KerasNLP、その他の依存関係をインストールします。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

バックエンドを選択

Keras は、シンプルさと使いやすさを考慮して設計された高レベルのマルチフレームワーク ディープ ラーニング API です。Keras 3 を使用すると、TensorFlow、JAX、PyTorch の 3 つのバックエンドのいずれかでワークフローを実行できます。

このチュートリアルでは、JAX のバックエンドを構成します。

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

パッケージをインポートする

Keras と KerasNLP をインポートします。

import keras
import keras_nlp

データセットの読み込み

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

データを前処理する。このチュートリアルでは、ノートブックをより迅速に実行するために、1,000 個のトレーニング サンプルのサブセットを使用します。より高品質なファインチューニングを行うには、より多くのトレーニング データを使用することを検討してください。

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

モデルを読み込む

KerasNLP には、多くの一般的なモデル アーキテクチャの実装が用意されています。このチュートリアルでは、因果言語モデル用のエンドツーエンドの Gemma モデルである GemmaCausalLM を使用してモデルを作成します。因果言語モデルは、以前のトークンに基づいて次のトークンを予測します。

from_preset メソッドを使用してモデルを作成します。

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

from_preset メソッドは、事前設定されたアーキテクチャと重みからモデルをインスタンス化します。上記のコードでは、文字列「gemma2_2b_en」で事前設定されたアーキテクチャ(20 億のパラメータを持つ Gemma モデル)を指定しています。

ファインチューニング前の推論

このセクションでは、さまざまなプロンプトを使用してモデルにクエリを実行し、モデルがどのように応答するかを確認します。

ヨーロッパ旅行のプロンプト

ヨーロッパ旅行でおすすめのアクティビティについてモデルにクエリを実行します。

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

モデルは、旅行の計画方法に関する一般的なヒントを返します。

ELI5 光合成プロンプト

5 歳の子供でも理解できる簡単な言葉で光合成を説明するようにモデルに指示します。

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

モデルの回答には、子供が理解しにくい単語(クロロフィルなど)が含まれている。

LoRA ファインチューニング

モデルからより良いレスポンスを得るには、Databricks Dolly 15k データセットを使用し、Low Rank Adaptation(LoRA)でモデルを微調整します。

LoRA ランクは、LLM の元の重みに追加されるトレーニング可能な行列の次元を決定します。微調整の表現力と精度を制御します。

ランクが高いほど、より詳細な変更が可能になりますが、トレーニング可能なパラメータも増えます。ランクが低いほど計算オーバーヘッドは少なくなりますが、適応の精度が低くなる可能性があります。

このチュートリアルでは、LoRA ランク 4 を使用します。実際には、比較的小さなランク(4、8、16 など)から始めます。これは、テストに計算効率が高い方法です。このランクを使用してモデルをトレーニングし、タスクのパフォーマンスの向上を評価します。その後のテストでランクを徐々に増やし、パフォーマンスがさらに向上するかどうかを確認します。

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

LoRA を有効にすると、トレーニング可能なパラメータの数は大幅に減少します(26 億から 290 万に減少)。

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

NVIDIA GPU での混合精度のファインチューニングに関する注意事項

ファインチューニングには完全な精度をおすすめします。NVIDIA GPU でファインチューニングを行う場合は、混合精度(keras.mixed_precision.set_global_policy('mixed_bfloat16'))を使用して、トレーニングの品質への影響を最小限に抑えながらトレーニングを高速化できます。混合精度のファインチューニングではメモリ使用量が増えるため、大規模な GPU でのみ有用です。

推論では、混合精度は適用されませんが、半精度(keras.config.set_floatx("bfloat16"))が機能し、メモリを節約できます。

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

ファインチューニング後の推論

ファインチューニング後、レスポンスはプロンプトで指定された手順に従います。

ヨーロッパ旅行のプロンプト

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

モデルは、ヨーロッパの観光スポットをおすすめするようになりました。

ELI5 光合成プロンプト

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

モデルは、光合成をより簡単に説明できるようになりました。

このチュートリアルでは、デモ目的で、データセットの小さなサブセットで 1 つのエポックのみ、低い LoRA ランク値でモデルをファインチューニングします。ファインチューニングされたモデルからより良いレスポンスを得るには、以下をテストします。

  1. ファインチューニング データセットのサイズを増やす
  2. より多くのステップ(エポック)のトレーニング
  3. LoRA ランクを高く設定する
  4. learning_rateweight_decay などのハイパーパラメータ値の変更。

まとめと次のステップ

このチュートリアルでは、KerasNLP を使用して Gemma モデルの LoRA ファインチューニングについて説明しました。次のドキュメントをご覧ください。