ปรับแต่ง Gemma สำหรับงานด้านวิสัยทัศน์โดยใช้ Hugging Face Transformers และ QLoRA

คู่มือนี้จะอธิบายวิธีปรับแต่ง Gemma ในชุดข้อมูลรูปภาพและข้อความที่กําหนดเองสําหรับงานด้านวิสัยทัศน์ (การสร้างคําอธิบายผลิตภัณฑ์) โดยใช้ Transformers และ TRL ของ Hugging Face คุณจะได้เรียนรู้:

  • Quantized Low-Rank Adaptation (QLoRA) คืออะไร
  • ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์
  • สร้างและเตรียมชุดข้อมูลสำหรับการปรับแต่งขั้นสุดท้ายสำหรับงานด้านวิสัยทัศน์
  • ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer
  • ทดสอบการอนุมานโมเดลและสร้างคำอธิบายผลิตภัณฑ์จากรูปภาพและข้อความ

Quantized Low-Rank Adaptation (QLoRA) คืออะไร

คู่มือนี้จะแสดงการใช้ Quantized Low-Rank Adaptation (QLoRA) ซึ่งเป็นวิธีการยอดนิยมในการปรับแต่ง LLM อย่างมีประสิทธิภาพ เนื่องจากช่วยลดความต้องการทรัพยากรการประมวลผลไปพร้อมกับคงประสิทธิภาพไว้สูง ใน QloRA โมเดลที่ผ่านการฝึกอบรมล่วงหน้าจะได้รับการแปลงเป็น 4 บิตและมีการตรึงน้ำหนัก จากนั้นจะแนบเลเยอร์อะแดปเตอร์ที่ฝึกได้ (LoRA) และฝึกเฉพาะเลเยอร์อะแดปเตอร์ หลังจากนั้น คุณสามารถผสานน้ำหนักของอะแดปเตอร์เข้ากับโมเดลพื้นฐานหรือเก็บไว้เป็นอะแดปเตอร์แยกต่างหากก็ได้

ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์

ขั้นตอนแรกคือการติดตั้งไลบรารี Hugging Face ซึ่งรวมถึง TRL และชุดข้อมูลเพื่อปรับแต่งโมเดลแบบเปิด

# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard torchvision

# Install Gemma release branch from Hugging Face
%pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  "pillow==11.1.0" \
  protobuf \
  sentencepiece

คุณต้องยอมรับข้อกำหนดในการใช้งาน Gemma ก่อนจึงจะเริ่มการฝึกได้ คุณสามารถยอมรับใบอนุญาตใน Hugging Face ได้โดยคลิกปุ่ม "ยอมรับและเข้าถึงที่เก็บ" ในหน้าโมเดลที่ http://huggingface.co/google/gemma-3-4b-pt (หรือหน้าโมเดลที่เหมาะสมสำหรับโมเดล Gemma ที่มีความสามารถในการมองเห็นที่คุณใช้อยู่)

หลังจากยอมรับใบอนุญาตแล้ว คุณต้องมีโทเค็น Hugging Face ที่ถูกต้องจึงจะเข้าถึงโมเดลได้ หากเรียกใช้ภายใน Google Colab คุณจะใช้โทเค็น Hugging Face ได้อย่างปลอดภัยโดยใช้ข้อมูลลับของ Colab หรือจะตั้งค่าโทเค็นในเมธอด login โดยตรงก็ได้ ตรวจสอบว่าโทเค็นของคุณมีสิทธิ์เขียนด้วย เนื่องจากคุณจะต้องพุชโมเดลไปยัง Hub ระหว่างการฝึก

from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

สร้างและเตรียมชุดข้อมูลสำหรับการปรับแต่ง

เมื่อปรับแต่ง LLM คุณต้องทราบกรณีการใช้งานและงานที่คุณต้องการแก้ปัญหา ซึ่งจะช่วยคุณสร้างชุดข้อมูลเพื่อปรับแต่งโมเดล หากยังไม่ได้กําหนดกรณีการใช้งาน คุณอาจต้องกลับไปที่กระดานวาดภาพ

ตัวอย่างเช่น คู่มือนี้มุ่งเน้นที่กรณีการใช้งานต่อไปนี้

  • การปรับแต่งโมเดล Gemma เพื่อสร้างรายละเอียดผลิตภัณฑ์ที่กระชับและเพิ่มประสิทธิภาพ SEO สำหรับแพลตฟอร์มอีคอมเมิร์ซ โดยปรับให้เหมาะกับการค้นหาบนอุปกรณ์เคลื่อนที่โดยเฉพาะ

คู่มือนี้ใช้ชุดข้อมูล philschmid/amazon-product-descriptions-vlm ซึ่งเป็นชุดข้อมูลคำอธิบายผลิตภัณฑ์ของ Amazon รวมถึงรูปภาพและหมวดหมู่ผลิตภัณฑ์

Hugging Face TRL รองรับการสนทนาแบบหลายรูปแบบ ส่วนสําคัญคือบทบาท "image" ซึ่งบอกคลาสการประมวลผลว่าควรโหลดรูปภาพ โครงสร้างควรเป็นไปตามรูปแบบต่อไปนี้

{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}

ตอนนี้คุณใช้คลังชุดข้อมูล Hugging Face เพื่อโหลดชุดข้อมูลและสร้างเทมเพลตพรอมต์เพื่อรวมรูปภาพ ชื่อผลิตภัณฑ์ และหมวดหมู่ รวมถึงเพิ่มข้อความของระบบได้แล้ว ชุดข้อมูลมีรูปภาพเป็นออบเจ็กต์Pil.Image

from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are an expert product description writer for Amazon."

# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
{product}
</PRODUCT>

<CATEGORY>
{category}
</CATEGORY>
"""

# Convert dataset to OAI messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            product=sample["Product Name"],
                            category=sample["Category"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["description"]}],
            },
        ],
    }

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs

# Load dataset from the hub
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")

# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
dataset = [format_data(sample) for sample in dataset]

print(dataset[345]["messages"])

ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer

ตอนนี้คุณก็พร้อมที่จะปรับแต่งโมเดลแล้ว Hugging Face TRL SFTTrainer ช่วยให้คุณควบคุมการปรับแต่ง LLM แบบเปิดได้อย่างง่ายดาย SFTTrainer เป็นคลาสย่อยของ Trainer จากไลบรารี transformers และรองรับฟีเจอร์เดียวกันทั้งหมด ซึ่งรวมถึงการบันทึก การประเมิน และการตรวจสอบจุดตรวจสอบ แต่เพิ่มฟีเจอร์อื่นๆ ที่ช่วยให้ใช้งานได้ง่ายขึ้น ซึ่งได้แก่

  • การจัดรูปแบบชุดข้อมูล รวมถึงรูปแบบการสนทนาและคำสั่ง
  • การฝึกอบรมเกี่ยวกับการดำเนินการให้เสร็จสมบูรณ์เท่านั้น โดยจะไม่สนใจพรอมต์
  • การแพ็กชุดข้อมูลเพื่อให้การฝึกมีประสิทธิภาพมากขึ้น
  • การรองรับการปรับแต่งแบบละเอียดที่มีประสิทธิภาพของพารามิเตอร์ (PEFT) รวมถึง QloRA
  • เตรียมโมเดลและตัวแยกวิเคราะห์เพื่อปรับแต่งการสนทนาแบบละเอียด (เช่น การเพิ่มโทเค็นพิเศษ)

โค้ดต่อไปนี้จะโหลดโมเดล Gemma และตัวแยกวิเคราะห์จาก Hugging Face และเริ่มต้นการกําหนดค่าการแปลงค่าเป็นจำนวนเต็ม

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

SFTTrainer รองรับการผสานรวมในตัวกับ peft ซึ่งช่วยให้ปรับแต่ง LLM ได้อย่างมีประสิทธิภาพโดยใช้ QLoRA คุณเพียงแค่สร้าง LoraConfig แล้วส่งให้เทรนเนอร์

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

คุณต้องกําหนดไฮเปอร์พารามิเตอร์ที่ต้องการใช้ใน SFTConfig และ collate_fn ที่กําหนดเองเพื่อจัดการการประมวลผลภาพก่อนจึงจะเริ่มการฝึกได้ collate_fn จะแปลงข้อความที่มีข้อความและรูปภาพเป็นรูปแบบที่โมเดลเข้าใจ

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-product-description",     # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    gradient_accumulation_steps=4,              # number of steps before performing a backward/update pass
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    push_to_hub=True,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
args.remove_unused_columns = False # important for collator

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch

ตอนนี้คุณมีองค์ประกอบพื้นฐานทั้งหมดในการสร้าง SFTTrainer เพื่อเริ่มการฝึกโมเดลแล้ว

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

เริ่มการฝึกโดยเรียกใช้เมธอด train()

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

อย่าลืมเพิ่มพื้นที่ว่างในหน่วยความจำก่อนจึงจะทดสอบโมเดลได้

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

เมื่อใช้ QLoRA คุณจะฝึกเฉพาะอะแดปเตอร์เท่านั้น ไม่ใช่ทั้งโมเดล ซึ่งหมายความว่าเมื่อบันทึกโมเดลระหว่างการฝึก คุณบันทึกเฉพาะน้ำหนักของอะแดปเตอร์เท่านั้น ไม่ใช่โมเดลทั้งหมด หากต้องการบันทึกโมเดลแบบเต็ม ซึ่งทําให้ใช้งานกับสแต็กการแสดงโฆษณา เช่น vLLM หรือ TGI ได้ง่ายขึ้น คุณสามารถผสานน้ำหนักของอะแดปเตอร์เข้ากับน้ำหนักของโมเดลได้โดยใช้เมธอด merge_and_unload จากนั้นบันทึกโมเดลด้วยเมธอด save_pretrained ซึ่งจะบันทึกโมเดลเริ่มต้นไว้เพื่อใช้สำหรับการอนุมาน

from peft import PeftModel

# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

ทดสอบการอนุมานโมเดลและสร้างรายละเอียดผลิตภัณฑ์

หลังจากการฝึกเสร็จสิ้นแล้ว คุณจะต้องประเมินและทดสอบโมเดล คุณสามารถโหลดตัวอย่างต่างๆ จากชุดทดสอบและประเมินโมเดลกับตัวอย่างเหล่านั้นได้

import torch

# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
  args.output_dir,
  device_map="auto",
  torch_dtype=torch.bfloat16,
  attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)

คุณสามารถทดสอบการอนุมานได้โดยระบุชื่อผลิตภัณฑ์ หมวดหมู่ และรูปภาพ sample มีฟิกเกอร์ Marvel

import requests
from PIL import Image

# Test sample with Product Name, Category and Image
sample = {
  "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
  "category": "Toys & Games | Toy Figures & Playsets | Action Figures",
  "image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}

def generate_description(sample, model, processor):
    # Convert sample into messages and then apply the chat template
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "image","image": sample["image"]},
            {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
        ]},
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process the image and text
    image_inputs = process_vision_info(messages)
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)
    
    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

# generate the description
description = generate_description(sample, model, processor)
print(description)

สรุปและขั้นตอนถัดไป

บทแนะนํานี้อธิบายวิธีปรับแต่งโมเดล Gemma สําหรับงานด้านวิสัยทัศน์โดยใช้ TRL และ QLoRA โดยเฉพาะสําหรับการสร้างคําอธิบายผลิตภัณฑ์ โปรดอ่านเอกสารต่อไปนี้