LoRA를 사용하여 Keras에서 Gemma 모델 미세 조정

ai.google.dev에서 보기 Google Colab에서 실행 Vertex AI에서 열기 GitHub에서 소스 보기

개요

Gemma는 Gemini 모델을 만드는 데 사용된 것과 동일한 연구 및 기술로 빌드된 최첨단 경량 개방형 모델군입니다.

Gemma와 같은 대규모 언어 모델 (LLM)은 다양한 NLP 태스크에 효과적임이 입증되었습니다. LLM은 먼저 대규모 텍스트 코퍼스를 사용하여 자기 감독 방식으로 선행 학습됩니다. 사전 학습을 통해 LLM은 단어 간의 통계적 관계와 같은 범용 지식을 학습할 수 있습니다. 그런 다음 LLM을 도메인별 데이터로 미세 조정하여 감정 분석과 같은 하류 태스크를 실행할 수 있습니다.

LLM은 크기가 매우 큽니다 (수십억 개 정도의 매개변수). 일반적인 미세 조정 데이터 세트는 사전 학습 데이터 세트보다 상대적으로 훨씬 작기 때문에 대부분의 애플리케이션에는 전체 미세 조정 (모델의 모든 매개변수를 업데이트함)이 필요하지 않습니다.

LoRA (Low Rank Adaptation)는 모델의 가중치를 고정하고 더 적은 수의 새 가중치를 모델에 삽입하여 다운스트림 태스크의 학습 가능한 매개변수 수를 크게 줄이는 미세 조정 기법입니다. 이렇게 하면 LoRA를 사용한 학습이 훨씬 더 빠르고 메모리 효율이 높아지며 모델 출력의 품질을 유지하면서 더 작은 모델 가중치 (수백 MB)를 생성할 수 있습니다.

이 튜토리얼에서는 KerasNLP를 사용하여 Databricks Dolly 15k 데이터 세트를 사용하여 Gemma 2B 모델에서 LoRA 미세 조정을 수행하는 방법을 안내합니다. 이 데이터 세트에는 LLM 미세 조정을 위해 특별히 설계된 고품질의 인간이 생성한 프롬프트 / 응답 쌍 15,000개가 포함되어 있습니다.

설정

Gemma 액세스 권한 획득하기

이 튜토리얼을 완료하려면 먼저 Gemma 설정에서 설정 안내를 완료해야 합니다. Gemma 설정 안내에서는 다음을 수행하는 방법을 보여줍니다.

  • kaggle.com에서 Gemma에 액세스합니다.
  • Gemma 2B 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택합니다.
  • Kaggle 사용자 이름과 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

런타임 선택

이 튜토리얼을 완료하려면 Gemma 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임이 필요합니다. 이 경우 T4 GPU를 사용할 수 있습니다.

  1. Colab 창의 오른쪽 상단에서 ▾ (추가 연결 옵션)을 선택합니다.
  2. 런타임 유형 변경을 선택합니다.
  3. 하드웨어 가속기에서 T4 GPU를 선택합니다.

API 키 구성

Gemma를 사용하려면 Kaggle 사용자 이름과 Kaggle API 키를 제공해야 합니다.

Kaggle API 키를 생성하려면 Kaggle 사용자 프로필의 계정 탭으로 이동하여 새 토큰 만들기를 선택합니다. 이렇게 하면 API 사용자 인증 정보가 포함된 kaggle.json 파일이 다운로드됩니다.

Colab에서 왼쪽 창의 보안 비밀 (🔑)을 선택하고 Kaggle 사용자 이름과 Kaggle API 키를 추가합니다. 사용자 이름은 KAGGLE_USERNAME라는 이름으로, API 키는 KAGGLE_KEY라는 이름으로 저장합니다.

환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

종속 항목 설치

Keras, KerasNLP, 기타 종속 항목을 설치합니다.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

백엔드 선택

Keras는 단순성과 사용 편의성을 위해 설계된 고급 멀티 프레임워크 딥 러닝 API입니다. Keras 3를 사용하면 TensorFlow, JAX, PyTorch 중 하나의 백엔드에서 워크플로를 실행할 수 있습니다.

이 튜토리얼에서는 JAX 백엔드를 구성합니다.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

패키지 가져오기

Keras 및 KerasNLP를 가져옵니다.

import keras
import keras_nlp

데이터 세트 로드

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

데이터를 사전 처리합니다. 이 튜토리얼에서는 1,000개의 학습 예시 중 일부를 사용하여 노트북을 더 빠르게 실행합니다. 고품질 미세 조정을 위해 더 많은 학습 데이터를 사용하는 것이 좋습니다.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

모델 로드

KerasNLP는 많이 사용되는 여러 모델 아키텍처의 구현을 제공합니다. 이 튜토리얼에서는 인과 언어 모델링을 위한 엔드 투 엔드 Gemma 모델인 GemmaCausalLM를 사용하여 모델을 만듭니다. 인과 언어 모델은 이전 토큰을 기반으로 다음 토큰을 예측합니다.

from_preset 메서드를 사용하여 모델을 만듭니다.

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

from_preset 메서드는 사전 설정된 아키텍처와 가중치에서 모델을 인스턴스화합니다. 위의 코드에서 'gemma2_2b_en' 문자열은 사전 설정된 아키텍처(20억 개의 매개변수를 가진 Gemma 모델)를 지정합니다.

미세 조정 전 추론

이 섹션에서는 다양한 프롬프트로 모델에 쿼리하여 모델이 어떻게 응답하는지 확인합니다.

유럽 여행 프롬프트

유럽 여행에서 할 일에 대한 추천을 모델에 쿼리합니다.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

모델은 여행 계획 방법에 관한 일반적인 도움말로 응답합니다.

ELI5 광합성 프롬프트

5세 어린이가 이해할 수 있을 만큼 간단한 용어로 광합성을 설명하도록 모델에 프롬프트를 제공합니다.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

모델 대답에는 어린이가 이해하기 쉽지 않은 단어(예: 엽록소)가 포함되어 있습니다.

LoRA 미세 조정

모델에서 더 나은 응답을 얻으려면 Databricks Dolly 15k 데이터 세트를 사용하여 낮은 순위 적응 (LoRA)으로 모델을 미세 조정합니다.

LoRA 순위는 LLM의 원래 가중치에 추가되는 학습 가능한 행렬의 차원을 결정합니다. 미세 조정의 표현력과 정밀도를 제어합니다.

순위가 높을수록 더 자세한 변경이 가능하지만 학습 가능한 매개변수도 더 많다는 의미입니다. 순위가 낮을수록 계산 오버헤드는 줄어들지만 적응 정확성은 떨어질 수 있습니다.

이 튜토리얼에서는 LoRA 순위 4를 사용합니다. 실제로는 비교적 작은 순위 (예: 4, 8, 16)로 시작합니다. 이는 실험에 있어 계산적으로 효율적입니다. 이 순위로 모델을 학습시키고 태스크의 성능 향상을 평가합니다. 후속 실험에서 순위를 점진적으로 높이고 실적이 더 개선되는지 확인합니다.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

LoRA를 사용 설정하면 학습 가능한 매개변수 수가 26억 개에서 290만 개로 크게 줄어듭니다.

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

NVIDIA GPU의 혼합 정밀도 미세 조정에 관한 참고사항

미세 조정에는 전체 정밀도가 권장됩니다. NVIDIA GPU에서 미세 조정할 때 혼합 정밀도 (keras.mixed_precision.set_global_policy('mixed_bfloat16'))를 사용하면 학습 품질에 미치는 영향을 최소화하면서 학습 속도를 높일 수 있습니다. 혼합 정밀도 미세 조정은 더 많은 메모리를 사용하므로 큰 GPU에서만 유용합니다.

추론의 경우 혼합 정밀도는 적용되지 않지만 절반 정밀도 (keras.config.set_floatx("bfloat16"))가 작동하여 메모리를 절약합니다.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

미세 조정 후 추론

미세 조정 후에는 프롬프트에 제공된 안내에 따라 응답이 이루어집니다.

유럽 여행 프롬프트

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

이제 모델이 유럽에서 가볼 만한 장소를 추천합니다.

ELI5 광합성 프롬프트

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

이제 모델은 광합성을 더 쉽게 설명합니다.

이 튜토리얼에서는 데모 목적으로 단 1번의 에포크 동안 LoRA 순위 값이 낮은 데이터 세트의 소규모 하위 집합에서 모델을 미세 조정합니다. 미세 조정된 모델에서 더 나은 대답을 얻으려면 다음을 실험해 보세요.

  1. 미세 조정 데이터 세트 크기 늘리기
  2. 더 많은 단계 (세대) 학습
  3. 더 높은 LoRA 순위 설정
  4. learning_rateweight_decay와 같은 초매개변수 값을 수정합니다.

요약 및 다음 단계

이 튜토리얼에서는 KerasNLP를 사용하여 Gemma 모델에서 LoRA 미세 조정을 다뤘습니다. 다음 문서를 확인하세요.