ปรับแต่ง Gemma สำหรับงานด้านวิสัยทัศน์โดยใช้ Hugging Face Transformers และ QLoRA

ดูที่ ai.google.dev เรียกใช้ใน Google Colab เรียกใช้ใน Kaggle เปิดใน Vertex AI ดูแหล่งข้อมูลใน GitHub

คู่มือนี้จะแนะนำวิธีปรับแต่ง Gemma ในชุดข้อมูลรูปภาพและข้อความที่กำหนดเองสำหรับงานด้านวิชันซิสเต็ม (การสร้างคำอธิบายผลิตภัณฑ์) โดยใช้ Transformers และ TRL ของ Hugging Face คุณจะได้เรียนรู้:

  • Quantized Low-Rank Adaptation (QLoRA) คืออะไร
  • ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์
  • สร้างและเตรียมชุดข้อมูลการปรับแต่ง
  • ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer
  • ทดสอบการอนุมานโมเดลและสร้างรายละเอียดผลิตภัณฑ์จากรูปภาพและข้อความ

Quantized Low-Rank Adaptation (QLoRA) คืออะไร

คู่มือนี้แสดงการใช้ Quantized Low-Rank Adaptation (QLoRA) ซึ่งกลายเป็นวิธีที่ได้รับความนิยมในการปรับแต่ง LLM อย่างมีประสิทธิภาพ เนื่องจากช่วยลดข้อกำหนดด้านทรัพยากรการคำนวณในขณะที่ยังคงประสิทธิภาพสูงไว้ได้ ใน QLoRA โมเดลที่ฝึกไว้ล่วงหน้าจะได้รับการควอนไทซ์เป็น 4 บิตและมีการตรึงน้ำหนัก จากนั้นจะมีการแนบเลเยอร์อะแดปเตอร์ที่ฝึกได้ (LoRA) และจะมีการฝึกเฉพาะเลเยอร์อะแดปเตอร์ หลังจากนั้น คุณจะผสานน้ำหนักของอะแดปเตอร์กับโมเดลพื้นฐานหรือเก็บไว้เป็นอะแดปเตอร์แยกต่างหากก็ได้

ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์

ขั้นตอนแรกคือการติดตั้งไลบรารี Hugging Face ซึ่งรวมถึง TRL และชุดข้อมูลเพื่อปรับแต่งโมเดลแบบเปิด

# Install Pytorch & other libraries
%pip install torch tensorboard torchvision

# Install Transformers
%pip install transformers

# Install Hugging Face libraries
%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf pillow sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#%pip install flash-attn

หมายเหตุ: หากใช้ GPU ที่มีสถาปัตยกรรม Ampere (เช่น NVIDIA L4) หรือใหม่กว่า คุณจะใช้ Flash Attention ได้ Flash Attention เป็นวิธีที่ช่วยเพิ่มความเร็วในการคำนวณได้อย่างมาก และลดการใช้งานหน่วยความจำจากกำลังสองเป็นเชิงเส้นในความยาวของลำดับ ซึ่งช่วยเร่งการฝึกได้สูงสุด 3 เท่า ดูข้อมูลเพิ่มเติมได้ที่ FlashAttention

คุณต้องมีโทเค็น Hugging Face ที่ใช้งานได้จึงจะเผยแพร่โมเดลได้ หากคุณเรียกใช้ภายใน Google Colab คุณจะใช้โทเค็น Hugging Face ได้อย่างปลอดภัยโดยใช้ความลับของ Colab หรือจะตั้งค่าโทเค็นโดยตรงในเมธอด login ก็ได้ ตรวจสอบว่าโทเค็นมีสิทธิ์เข้าถึงแบบเขียนด้วย เนื่องจากคุณจะพุชโมเดลไปยัง Hub ในระหว่างการฝึก

# Login into Hugging Face Hub
from huggingface_hub import login
login()

สร้างและเตรียมชุดข้อมูลการปรับแต่ง

เมื่อปรับแต่ง LLM สิ่งสำคัญคือคุณต้องทราบกรณีการใช้งานและงานที่ต้องการแก้ไข ซึ่งจะช่วยให้คุณสร้างชุดข้อมูลเพื่อปรับแต่งโมเดลได้ หากยังไม่ได้กำหนดกรณีการใช้งาน คุณอาจต้องกลับไปที่กระดานออกแบบ

ตัวอย่างเช่น คู่มือนี้มุ่งเน้นไปที่กรณีการใช้งานต่อไปนี้

  • การปรับแต่งโมเดล Gemma เพื่อสร้างรายละเอียดผลิตภัณฑ์ที่กระชับและเพิ่มประสิทธิภาพ SEO สำหรับแพลตฟอร์มอีคอมเมิร์ซ โดยเฉพาะอย่างยิ่งการค้นหาบนอุปกรณ์เคลื่อนที่

คู่มือนี้ใช้ชุดข้อมูล philschmid/amazon-product-descriptions-vlm ซึ่งเป็นชุดข้อมูลคำอธิบายผลิตภัณฑ์ของ Amazon รวมถึงรูปภาพและหมวดหมู่ผลิตภัณฑ์

Hugging Face TRL รองรับการสนทนาแบบมัลติโมดัล ส่วนสำคัญคือบทบาท "image" ซึ่งจะบอกคลาสการประมวลผลว่าควรโหลดรูปภาพ โครงสร้างควรเป็นดังนี้

{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}

ตอนนี้คุณสามารถใช้ไลบรารีชุดข้อมูลของ Hugging Face เพื่อโหลดชุดข้อมูลและสร้างเทมเพลตพรอมต์เพื่อรวมรูปภาพ ชื่อผลิตภัณฑ์ และหมวดหมู่ รวมถึงเพิ่มข้อความระบบได้แล้ว ชุดข้อมูลประกอบด้วยรูปภาพเป็นออบเจ็กต์Pil.Image

from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are an expert product description writer for Amazon."

# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
{product}
</PRODUCT>

<CATEGORY>
{category}
</CATEGORY>
"""

# Convert dataset to OAI messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                #"content": [{"type": "text", "text": system_message}],
                "content": system_message,
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            product=sample["Product Name"],
                            category=sample["Category"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["description"]}],
            },
        ],
    }

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs

# Load dataset from the hub
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")
dataset = dataset.train_test_split(test_size=0.1)

# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
dataset_train = [format_data(sample) for sample in dataset["train"]]
dataset_test = [format_data(sample) for sample in dataset["test"]]

print(dataset_train[345]["messages"])
README.md: 0.00B [00:00, ?B/s]
data/train-00000-of-00001.parquet:   0%|          | 0.00/47.6M [00:00<?, ?B/s]
Generating train split:   0%|          | 0/1345 [00:00<?, ? examples/s]
[{'role': 'system', 'content': 'You are an expert product description writer for Amazon.'}, {'role': 'user', 'content': [{'type': 'text', 'text': "Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\nOnly return description. The description should be SEO optimized and for a better mobile search experience.\n\n<PRODUCT>\nRazor Agitator BMX/Freestyle Bike, 20-Inch\n</PRODUCT>\n\n<CATEGORY>\nSports & Outdoors | Outdoor Recreation | Cycling | Kids' Bikes & Accessories | Kids' Bikes\n</CATEGORY>\n"}, {'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x413 at 0x7B7250181790>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Conquer the streets with the Razor Agitator BMX Bike! This 20-inch freestyle bike is built for young riders ready to take on any challenge. Durable frame, responsive handling – perfect for tricks and cruising.  Get yours today!'}]}]

ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer

ตอนนี้คุณพร้อมที่จะปรับแต่งโมเดลแล้ว SFTTrainer ของ TRL จาก Hugging Face ช่วยให้การดูแลการปรับแต่ง LLM แบบเปิดเป็นเรื่องง่าย SFTTrainer เป็นคลาสย่อยของ Trainer จากไลบรารี transformers และรองรับฟีเจอร์ทั้งหมดเหมือนกัน ซึ่งรวมถึงการบันทึก การประเมิน และการตรวจสอบ แต่เพิ่มฟีเจอร์เพิ่มเติมเพื่ออำนวยความสะดวกในการใช้งาน ซึ่งรวมถึงสิ่งต่อไปนี้

  • การจัดรูปแบบชุดข้อมูล รวมถึงรูปแบบการสนทนาและคำสั่ง
  • การฝึกโดยพิจารณาเฉพาะการตอบคำถามเสร็จสมบูรณ์เท่านั้น โดยไม่สนใจพรอมต์
  • การแพ็กชุดข้อมูลเพื่อการฝึกที่มีประสิทธิภาพมากขึ้น
  • รองรับการปรับแต่งพารามิเตอร์อย่างมีประสิทธิภาพ (PEFT) รวมถึง QLoRA
  • เตรียมโมเดลและโทเค็นไนเซอร์สำหรับการปรับแต่งแบบสนทนา (เช่น การเพิ่มโทเค็นพิเศษ)

โค้ดต่อไปนี้จะโหลดโมเดลและโทเค็นไนเซอร์ Gemma จาก Hugging Face และเริ่มต้นการกำหนดค่าการวัดปริมาณ

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-4-E2B" # @param ["google/gemma-4-E2B","google/gemma-4-E4B"] {"allow-input":true}

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["dtype"],
    bnb_4bit_quant_storage=model_kwargs["dtype"],
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it") # Load the Instruction Tokenizer to use the official Gemma template
config.json: 0.00B [00:00, ?B/s]
model.safetensors:   0%|          | 0.00/10.2G [00:00<?, ?B/s]
Loading weights:   0%|          | 0/2011 [00:00<?, ?it/s]
generation_config.json:   0%|          | 0.00/149 [00:00<?, ?B/s]
processor_config.json: 0.00B [00:00, ?B/s]
chat_template.jinja: 0.00B [00:00, ?B/s]
config.json: 0.00B [00:00, ?B/s]
tokenizer_config.json: 0.00B [00:00, ?B/s]
tokenizer.json:   0%|          | 0.00/32.2M [00:00<?, ?B/s]

SFTTrainer รองรับการผสานรวมในตัวกับ peft ซึ่งช่วยให้ปรับแต่ง LLM ได้อย่างมีประสิทธิภาพโดยใช้ QLoRA คุณเพียงแค่ต้องสร้าง LoraConfig และมอบให้แก่ผู้ฝึกสอน

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"], # make sure to save the lm_head and embed_tokens as you train the special tokens
    ensure_weight_tying=True,
)

ก่อนที่จะเริ่มการฝึก คุณต้องกำหนดไฮเปอร์พารามิเตอร์ที่ต้องการใช้ใน SFTConfig และ collate_fn ที่กำหนดเองเพื่อจัดการการประมวลผลภาพ collate_fn จะแปลงข้อความที่มีข้อความและรูปภาพเป็นรูปแบบที่โมเดลเข้าใจได้

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-product-description",     # directory to save and repository id
    num_train_epochs=3,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    eval_strategy="epoch",                      # evaluate checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    push_to_hub=True,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
    remove_unused_columns = False               # important for collator
)

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100

    batch["labels"] = labels
    return batch

ตอนนี้คุณมีองค์ประกอบที่ใช้สร้างสรรค์ทุกอย่างที่จำเป็นในการสร้าง SFTTrainer เพื่อเริ่มฝึกโมเดลแล้ว

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

เริ่มการฝึกโดยเรียกใช้เมธอด train()

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.

ก่อนที่จะทดสอบโมเดล โปรดตรวจสอบว่าได้เพิ่มพื้นที่ว่างในหน่วยความจำแล้ว

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

เมื่อใช้ QLoRA คุณจะฝึกเฉพาะอแดปเตอร์ ไม่ใช่โมเดลทั้งหมด ซึ่งหมายความว่าเมื่อบันทึกโมเดลระหว่างการฝึก คุณจะบันทึกเฉพาะน้ำหนักของอแดปเตอร์ ไม่ใช่โมเดลทั้งหมด หากต้องการบันทึกโมเดลแบบเต็ม ซึ่งจะช่วยให้ใช้งานกับสแต็กการแสดงผล เช่น vLLM หรือ TGI ได้ง่ายขึ้น คุณสามารถผสานน้ำหนักของอแดปเตอร์เข้ากับน้ำหนักของโมเดลได้โดยใช้เมธอด merge_and_unload จากนั้นบันทึกโมเดลด้วยเมธอด save_pretrained ซึ่งจะบันทึกโมเดลเริ่มต้นที่ใช้สำหรับการอนุมานได้

from peft import PeftModel

# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
Loading weights:   0%|          | 0/2011 [00:00<?, ?it/s]
Writing model shards:   0%|          | 0/5 [00:00<?, ?it/s]
['merged_model/processor_config.json']

ทดสอบการอนุมานโมเดลและสร้างรายละเอียดผลิตภัณฑ์

หลังจากฝึกแล้ว คุณจะต้องประเมินและทดสอบโมเดล คุณโหลดตัวอย่างต่างๆ จากชุดข้อมูลทดสอบและประเมินโมเดลในตัวอย่างเหล่านั้นได้

model_id = "merged_model"

# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
  model_id,
  device_map="auto",
  dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
Loading weights:   0%|          | 0/2012 [00:00<?, ?it/s]
The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.

คุณสามารถทดสอบการอนุมานได้โดยระบุชื่อผลิตภัณฑ์ หมวดหมู่ และรูปภาพ sample มาพร้อมฟิกเกอร์แอ็กชันของมาร์เวล

import requests
from PIL import Image

# Test sample with Product Name, Category and Image
sample = {
  "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
  "category": "Toys & Games | Toy Figures & Playsets | Action Figures",
  "image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}

def generate_description(sample, model, processor):
    # Convert sample into messages and then apply the chat template
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": [
            {"type": "image","image": sample["image"]},
            {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
        ]},
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(text)
    # Process the image and text
    image_inputs = process_vision_info(messages)
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)

    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<turn|>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

# generate the description
description = generate_description(sample, model, processor)
print("MODEL OUTPUT>> \n")
print(description)
<bos><|turn>system
You are an expert product description writer for Amazon.<turn|>
<|turn>user


<|image|>

Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur
</PRODUCT>

<CATEGORY>
Toys & Games | Toy Figures & Playsets | Action Figures
</CATEGORY><turn|>
<|turn>model

MODEL OUTPUT>> 

Enhance your collection with the Marvel Avengers - Avengers Assemble Ultron-Comforter Set! This soft and cuddly blanket and pillowcase feature everyone's favorite Avengers, Iron Man, and his loyal companion War Machine. Officially licensed by Marvel.  Bring home the heroic team!

สรุปและขั้นตอนถัดไป

บทแนะนำนี้ครอบคลุมวิธีปรับแต่งโมเดล Gemma สำหรับงานด้านวิชันซิสเต็มโดยใช้ TRL และ QLoRA โดยเฉพาะอย่างยิ่งสำหรับการสร้างคำอธิบายผลิตภัณฑ์ ดูเอกสารต่อไปนี้