이 가이드에서는 Hugging Face Transformers 및 TRL을 사용하여 비전 작업 (제품 설명 생성)을 위한 맞춤 이미지 및 텍스트 데이터 세트에서 Gemma를 미세 조정하는 방법을 안내합니다. 학습 내용
- 정량화된 LoRA (QLoRA)란 무엇인가요?
- 개발 환경 설정
- 비전 태스크를 위한 미세 조정 데이터 세트 만들기 및 준비
- TRL 및 SFTTrainer를 사용하여 Gemma 미세 조정
- 모델 추론을 테스트하고 이미지와 텍스트에서 제품 설명을 생성합니다.
정량화된 LoRA (QLoRA)란 무엇인가요?
이 가이드에서는 LLM을 효율적으로 미세 조정하는 인기 있는 방법으로 떠오른 정량화된 하위 순위 조정 (QLoRA)의 사용을 보여줍니다. QLoRA는 높은 성능을 유지하면서 계산 리소스 요구사항을 줄여주기 때문입니다. QloRA에서는 선행 학습된 모델이 4비트로 양자화되고 가중치가 고정됩니다. 그런 다음 학습 가능한 어댑터 레이어 (LoRA)가 연결되고 어댑터 레이어만 학습됩니다. 그런 다음 어댑터 가중치를 기본 모델과 병합하거나 별도의 어댑터로 유지할 수 있습니다.
개발 환경 설정
첫 번째 단계는 TRL을 비롯한 Hugging Face 라이브러리와 오픈 모델을 미세 조정할 데이터 세트를 설치하는 것입니다.
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard torchvision
# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"
# Install Hugging Face libraries
%pip install --upgrade \
"datasets==3.3.2" \
"accelerate==1.4.0" \
"evaluate==0.4.3" \
"bitsandbytes==0.45.3" \
"trl==0.15.2" \
"peft==0.14.0" \
"pillow==11.1.0" \
protobuf \
sentencepiece
학습을 시작하려면 먼저 Gemma의 사용 약관에 동의해야 합니다. Hugging Face의 모델 페이지 (http://huggingface.co/google/gemma-3-4b-pt 또는 사용 중인 시각 기능이 있는 Gemma 모델의 적절한 모델 페이지)에서 '동의 및 저장소 액세스' 버튼을 클릭하여 라이선스를 수락할 수 있습니다.
라이선스에 동의한 후에는 유효한 Hugging Face 토큰이 있어야 모델에 액세스할 수 있습니다. Google Colab 내에서 실행하는 경우 Colab 보안 비밀을 사용하여 Hugging Face 토큰을 안전하게 사용할 수 있습니다. 그러지 않으면 login
메서드에서 토큰을 직접 설정할 수 있습니다. 학습 중에 모델을 허브에 푸시할 때 토큰에 쓰기 액세스 권한도 있는지 확인합니다.
from google.colab import userdata
from huggingface_hub import login
# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)
미세 조정 데이터 세트 만들기 및 준비
LLM을 미세 조정할 때는 사용 사례와 해결하려는 작업을 파악하는 것이 중요합니다. 이렇게 하면 모델을 미세 조정할 데이터 세트를 만들 수 있습니다. 아직 사용 사례를 정의하지 않았다면 다시 설계 단계로 돌아가야 할 수 있습니다.
예를 들어 이 가이드에서는 다음과 같은 사용 사례에 중점을 둡니다.
- Gemma 모델을 미세 조정하여 전자상거래 플랫폼에 맞춤설정된 간결하고 검색엔진에 최적화된 제품 설명을 생성합니다(특히 모바일 검색에 맞춤설정).
이 가이드에서는 제품 이미지 및 카테고리를 포함한 Amazon 제품 설명 데이터인 philschmid/amazon-product-descriptions-vlm 데이터 세트를 사용합니다.
Hugging Face TRL은 멀티모달 대화를 지원합니다. 중요한 부분은 처리 클래스에 이미지를 로드해야 한다고 알려주는 'image' 역할입니다. 구조는 다음을 따라야 합니다.
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
이제 Hugging Face Datasets 라이브러리를 사용하여 데이터 세트를 로드하고 프롬프트 템플릿을 만들어 이미지, 제품 이름, 카테고리를 결합하고 시스템 메시지를 추가할 수 있습니다. 데이터 세트에는 이미지가 Pil.Image
객체로 포함됩니다.
from datasets import load_dataset
from PIL import Image
# System message for the assistant
system_message = "You are an expert product description writer for Amazon."
# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.
<PRODUCT>
{product}
</PRODUCT>
<CATEGORY>
{category}
</CATEGORY>
"""
# Convert dataset to OAI messages
def format_data(sample):
return {
"messages": [
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt.format(
product=sample["Product Name"],
category=sample["Category"],
),
},
{
"type": "image",
"image": sample["image"],
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["description"]}],
},
],
}
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
# Iterate through each conversation
for msg in messages:
# Get content (ensure it's a list)
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
# Check each content element for images
for element in content:
if isinstance(element, dict) and (
"image" in element or element.get("type") == "image"
):
# Get the image and convert to RGB
if "image" in element:
image = element["image"]
else:
image = element
image_inputs.append(image.convert("RGB"))
return image_inputs
# Load dataset from the hub
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")
# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
dataset = [format_data(sample) for sample in dataset]
print(dataset[345]["messages"])
TRL 및 SFTTrainer를 사용하여 Gemma 미세 조정
이제 모델을 미세 조정할 준비가 되었습니다. Hugging Face TRL SFTTrainer를 사용하면 개방형 LLM의 미세 조정을 간편하게 감독할 수 있습니다. SFTTrainer
는 transformers
라이브러리의 Trainer
의 서브클래스이며 로깅, 평가, 체크포인트 설정을 비롯한 동일한 모든 기능을 지원하지만 다음과 같은 편의 기능을 추가로 제공합니다.
- 대화형 및 안내 형식을 포함한 데이터 세트 형식 지정
- 프롬프트를 무시하고 완료만 학습
- 더 효율적인 학습을 위해 데이터 세트 패킹
- QloRA를 포함한 매개변수 효율적인 미세 조정 (PEFT) 지원
- 대화형 미세 조정을 위한 모델 및 토큰 생성기 준비 (예: 특수 토큰 추가)
다음 코드는 Hugging Face에서 Gemma 모델과 토큰라이저를 로드하고 정규화 구성을 초기화합니다.
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
# Define model init arguments
model_kwargs = dict(
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
device_map="auto", # Let torch decide how to load the model
)
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
SFTTrainer
는 peft
와의 기본 제공 통합을 지원하므로 QLoRA를 사용하여 LLM을 간편하고 효율적으로 조정할 수 있습니다. LoraConfig
를 만들고 트레이너에게 제공하기만 하면 됩니다.
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=[
"lm_head",
"embed_tokens",
],
)
학습을 시작하려면 먼저 SFTConfig
에서 사용할 초매개변수와 비전 처리를 처리할 맞춤 collate_fn
를 정의해야 합니다. collate_fn
는 텍스트와 이미지가 포함된 메시지를 모델이 이해할 수 있는 형식으로 변환합니다.
from trl import SFTConfig
args = SFTConfig(
output_dir="gemma-product-description", # directory to save and repository id
num_train_epochs=1, # number of training epochs
per_device_train_batch_size=1, # batch size per device during training
gradient_accumulation_steps=4, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory
optim="adamw_torch_fused", # use fused adamw optimizer
logging_steps=5, # log every 5 steps
save_strategy="epoch", # save checkpoint every epoch
learning_rate=2e-4, # learning rate, based on QLoRA paper
bf16=True, # use bfloat16 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to hub
report_to="tensorboard", # report metrics to tensorboard
gradient_checkpointing_kwargs={
"use_reentrant": False
}, # use reentrant checkpointing
dataset_text_field="", # need a dummy field for collator
dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
)
args.remove_unused_columns = False # important for collator
# Create a data collator to encode text and image pairs
def collate_fn(examples):
texts = []
images = []
for example in examples:
image_inputs = process_vision_info(example["messages"])
text = processor.apply_chat_template(
example["messages"], add_generation_prompt=False, tokenize=False
)
texts.append(text.strip())
images.append(image_inputs)
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
labels = batch["input_ids"].clone()
# Mask image tokens
image_token_id = [
processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.special_tokens_map["boi_token"]
)
]
# Mask tokens for not being used in the loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
labels[labels == 262144] = -100
batch["labels"] = labels
return batch
이제 SFTTrainer
를 만들고 모델 학습을 시작하는 데 필요한 모든 구성요소가 준비되었습니다.
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset,
peft_config=peft_config,
processing_class=processor,
data_collator=collate_fn,
)
train()
메서드를 호출하여 학습을 시작합니다.
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
# Save the final model again to the Hugging Face Hub
trainer.save_model()
모델을 테스트하기 전에 메모리를 해제해야 합니다.
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
QLoRA를 사용하면 전체 모델이 아닌 어댑터만 학습합니다. 즉, 학습 중에 모델을 저장할 때는 전체 모델이 아닌 어댑터 가중치만 저장됩니다. vLLM 또는 TGI와 같은 서빙 스택에서 더 쉽게 사용할 수 있도록 전체 모델을 저장하려면 merge_and_unload
메서드를 사용하여 어댑터 가중치를 모델 가중치에 병합한 다음 save_pretrained
메서드로 모델을 저장하면 됩니다. 이렇게 하면 추론에 사용할 수 있는 기본 모델이 저장됩니다.
from peft import PeftModel
# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
모델 추론 테스트 및 제품 설명 생성
학습이 완료되면 모델을 평가하고 테스트해야 합니다. 테스트 데이터 세트에서 다양한 샘플을 로드하고 이러한 샘플에서 모델을 평가할 수 있습니다.
import torch
# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
args.output_dir,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)
제품 이름, 카테고리, 이미지를 제공하여 추론을 테스트할 수 있습니다. sample
에는 마블 액션 피규어가 포함되어 있습니다.
import requests
from PIL import Image
# Test sample with Product Name, Category and Image
sample = {
"product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
"category": "Toys & Games | Toy Figures & Playsets | Action Figures",
"image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}
def generate_description(sample, model, processor):
# Convert sample into messages and then apply the chat template
messages = [
{"role": "system", "content": [{"type": "text", "text": system_message}]},
{"role": "user", "content": [
{"type": "image","image": sample["image"]},
{"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
]},
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Process the image and text
image_inputs = process_vision_info(messages)
# Tokenize the text and process the images
inputs = processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
# Move the inputs to the device
inputs = inputs.to(model.device)
# Generate the output
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
# Trim the generation and decode the output to text
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0]
# generate the description
description = generate_description(sample, model, processor)
print(description)
요약 및 다음 단계
이 튜토리얼에서는 특히 제품 설명을 생성하기 위해 TRL 및 QLoRA를 사용하여 비전 작업을 위한 Gemma 모델을 미세 조정하는 방법을 설명했습니다. 다음 문서를 확인하세요.
- Gemma 모델로 텍스트를 생성하는 방법을 알아보세요.
- Hugging Face Transformers를 사용하여 텍스트 작업을 위해 Gemma를 미세 조정하는 방법을 알아봅니다.
- Gemma 모델에서 분산 파인 튜닝 및 추론을 실행하는 방법을 알아보세요.
- Vertex AI에서 Gemma 개방형 모델을 사용하는 방법을 알아보세요.
- KerasNLP를 사용하여 Gemma를 미세 조정하고 Vertex AI에 배포하는 방법을 알아봅니다.