|
|
在 Google Colab 中运行
|
|
|
在 GitHub 上查看源代码
|
本指南将引导您了解如何使用 Hugging Face Transformers 和 TRL 在自定义图片和文本数据集上微调 Gemma,以完成视觉任务(生成商品说明)。您将了解:
- 什么是量化低秩适应 (QLoRA)
- 设置开发环境
- 创建和准备微调数据集
- 使用 TRL 和 SFTTrainer 微调 Gemma
- 测试模型推理并根据图片和文本生成商品说明。
什么是量化低秩适应 (QLoRA)
本指南演示了 量化低秩适应 (QLoRA) 的使用。QLoRA 是一种新兴的热门方法,可高效地微调 LLM,因为它在保持高性能的同时降低了计算资源要求。在 QloRA 中,预训练模型被量化为 4 位,并且权重被冻结。然后,可训练的适配器层 (LoRA) 会被附加,并且仅训练适配器层。之后,适配器权重可以与基本模型合并,也可以保留为单独的适配器。
设置开发环境
第一步是安装 Hugging Face 库,包括 TRL 和数据集,以微调开放模型。
# Install Pytorch & other libraries
%pip install torch tensorboard torchvision
# Install Transformers
%pip install transformers
# Install Hugging Face libraries
%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf pillow sentencepiece
# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#%pip install flash-attn
注意:如果您使用的是 Ampere 架构(例如 NVIDIA L4)或更新的 GPU,则可以使用 Flash 注意力机制。Flash 注意力机制是一种可显著加快计算速度并减少内存使用量的方法,可将内存使用量从序列长度的二次方降低到线性,从而将训练速度提高 3 倍。如需了解详情,请参阅 FlashAttention。
您需要有效的 Hugging Face 令牌才能发布模型。如果您在 Google Colab 中运行,则可以使用 Colab 密钥安全地使用 Hugging Face 令牌,否则可以直接在 login 方法中设置令牌。请确保您的令牌也具有写入权限,因为您会在训练期间将模型推送到 Hub。
# Login into Hugging Face Hub
from huggingface_hub import login
login()
创建和准备微调数据集
微调 LLM 时,了解您的使用场景和要解决的任务非常重要。这有助于您创建数据集来微调模型。如果您尚未定义使用场景,可能需要重新开始。
例如,本指南重点介绍以下使用场景:
- 微调 Gemma 模型,为电子商务平台生成简洁、经过 SEO 优化的商品说明,专门针对移动搜索。
本指南使用 philschmid/amazon-product-descriptions-vlm 数据集,这是一个包含 Amazon 商品说明的数据集,包括商品图片和类别。
Hugging Face TRL 支持多模态对话。重要部分是“image”角色,它会告知处理类应加载图片。结构应遵循:
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
现在,您可以使用 Hugging Face Datasets 库加载数据集,并创建提示模板以组合图片、商品名称和类别,并添加系统消息。数据集包含图片作为 Pil.Image 对象。
from datasets import load_dataset
from PIL import Image
# System message for the assistant
system_message = "You are an expert product description writer for Amazon."
# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.
<PRODUCT>
{product}
</PRODUCT>
<CATEGORY>
{category}
</CATEGORY>
"""
# Convert dataset to OAI messages
def format_data(sample):
return {
"messages": [
{
"role": "system",
#"content": [{"type": "text", "text": system_message}],
"content": system_message,
},
{
"role": "user",
"content": [
{
"type": "text",
"text": user_prompt.format(
product=sample["Product Name"],
category=sample["Category"],
),
},
{
"type": "image",
"image": sample["image"],
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["description"]}],
},
],
}
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
# Iterate through each conversation
for msg in messages:
# Get content (ensure it's a list)
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
# Check each content element for images
for element in content:
if isinstance(element, dict) and (
"image" in element or element.get("type") == "image"
):
# Get the image and convert to RGB
if "image" in element:
image = element["image"]
else:
image = element
image_inputs.append(image.convert("RGB"))
return image_inputs
# Load dataset from the hub
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")
dataset = dataset.train_test_split(test_size=0.1)
# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
dataset_train = [format_data(sample) for sample in dataset["train"]]
dataset_test = [format_data(sample) for sample in dataset["test"]]
print(dataset_train[345]["messages"])
README.md: 0.00B [00:00, ?B/s]
data/train-00000-of-00001.parquet: 0%| | 0.00/47.6M [00:00<?, ?B/s]
Generating train split: 0%| | 0/1345 [00:00<?, ? examples/s]
[{'role': 'system', 'content': 'You are an expert product description writer for Amazon.'}, {'role': 'user', 'content': [{'type': 'text', 'text': "Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\nOnly return description. The description should be SEO optimized and for a better mobile search experience.\n\n<PRODUCT>\nRazor Agitator BMX/Freestyle Bike, 20-Inch\n</PRODUCT>\n\n<CATEGORY>\nSports & Outdoors | Outdoor Recreation | Cycling | Kids' Bikes & Accessories | Kids' Bikes\n</CATEGORY>\n"}, {'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x413 at 0x7B7250181790>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Conquer the streets with the Razor Agitator BMX Bike! This 20-inch freestyle bike is built for young riders ready to take on any challenge. Durable frame, responsive handling – perfect for tricks and cruising. Get yours today!'}]}]
使用 TRL 和 SFTTrainer 微调 Gemma
现在,您可以微调模型了。Hugging Face TRL SFTTrainer 可让您轻松监督微调开放 LLM。SFTTrainer 是 transformers 库中 Trainer 的子类,支持所有相同的功能,包括日志记录、评估和检查点,但添加了其他生活质量功能,包括:
- 数据集格式设置,包括对话和说明格式
- 仅针对补全进行训练,忽略提示
- 打包数据集以提高训练效率
- 参数高效微调 (PEFT) 支持,包括 QloRA
- 准备模型和分词器以进行对话微调(例如添加特殊令牌)
以下代码从 Hugging Face 加载 Gemma 模型和分词器,并初始化量化配置。
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-4-E2B" # @param ["google/gemma-4-E2B","google/gemma-4-E4B"] {"allow-input":true}
# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
# Define model init arguments
model_kwargs = dict(
dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
device_map="auto", # Let torch decide how to load the model
)
# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=model_kwargs["dtype"],
bnb_4bit_quant_storage=model_kwargs["dtype"],
)
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it") # Load the Instruction Tokenizer to use the official Gemma template
config.json: 0.00B [00:00, ?B/s] model.safetensors: 0%| | 0.00/10.2G [00:00<?, ?B/s] Loading weights: 0%| | 0/2011 [00:00<?, ?it/s] generation_config.json: 0%| | 0.00/149 [00:00<?, ?B/s] processor_config.json: 0.00B [00:00, ?B/s] chat_template.jinja: 0.00B [00:00, ?B/s] config.json: 0.00B [00:00, ?B/s] tokenizer_config.json: 0.00B [00:00, ?B/s] tokenizer.json: 0%| | 0.00/32.2M [00:00<?, ?B/s]
SFTTrainer 支持与 peft 的内置集成,这使得使用 QLoRA 高效调优 LLM 变得简单。您只需创建 LoraConfig 并将其提供给训练器即可。
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=["lm_head", "embed_tokens"], # make sure to save the lm_head and embed_tokens as you train the special tokens
ensure_weight_tying=True,
)
在开始训练之前,您需要在 SFTConfig 中定义要使用的超参数,并定义自定义 collate_fn 来处理视觉处理。collate_fn 会将包含文本和图片的消息转换为模型可以理解的格式。
from trl import SFTConfig
args = SFTConfig(
output_dir="gemma-product-description", # directory to save and repository id
num_train_epochs=3, # number of training epochs
per_device_train_batch_size=1, # batch size per device during training
optim="adamw_torch_fused", # use fused adamw optimizer
logging_steps=5, # log every 5 steps
save_strategy="epoch", # save checkpoint every epoch
eval_strategy="epoch", # evaluate checkpoint every epoch
learning_rate=2e-4, # learning rate, based on QLoRA paper
bf16=True, # use bfloat16 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to hub
report_to="tensorboard", # report metrics to tensorboard
dataset_text_field="", # need a dummy field for collator
dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
remove_unused_columns = False # important for collator
)
# Create a data collator to encode text and image pairs
def collate_fn(examples):
texts = []
images = []
for example in examples:
image_inputs = process_vision_info(example["messages"])
text = processor.apply_chat_template(
example["messages"], add_generation_prompt=False, tokenize=False
)
texts.append(text.strip())
images.append(image_inputs)
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
labels = batch["input_ids"].clone()
# Mask tokens for not being used in the loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == processor.tokenizer.boi_token_id] = -100
labels[labels == processor.tokenizer.image_token_id] = -100
labels[labels == processor.tokenizer.eoi_token_id] = -100
batch["labels"] = labels
return batch
现在,您已拥有创建 SFTTrainer 以开始训练模型所需的所有智能模块。
from trl import SFTTrainer
# Create Trainer object
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset_train,
eval_dataset=dataset_test,
peft_config=peft_config,
processing_class=processor,
data_collator=collate_fn,
)
调用 train() 方法开始训练。
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
# Save the final model again to the Hugging Face Hub
trainer.save_model()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.
在测试模型之前,请务必释放内存。
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
使用 QLoRA 时,您只需训练适配器,而无需训练完整模型。这意味着,在训练期间保存模型时,您只需保存适配器权重,而无需保存完整模型。如果您想保存完整模型,以便更轻松地将其与 vLLM 或 TGI 等服务堆栈搭配使用,则可以使用 merge_and_unload 方法将适配器权重合并到模型权重中,然后使用 save_pretrained 方法保存模型。这会保存一个默认模型,该模型可用于推理。
from peft import PeftModel
# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
Loading weights: 0%| | 0/2011 [00:00<?, ?it/s] Writing model shards: 0%| | 0/5 [00:00<?, ?it/s] ['merged_model/processor_config.json']
测试模型推理并生成商品说明
训练完成后,您需要评估和测试模型。您可以从测试数据集中加载不同的样本,并针对这些样本评估模型。
model_id = "merged_model"
# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
Loading weights: 0%| | 0/2012 [00:00<?, ?it/s] The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.
您可以通过提供商品名称、类别和图片来测试推理。sample 包含一个漫威动作人偶。
import requests
from PIL import Image
# Test sample with Product Name, Category and Image
sample = {
"product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
"category": "Toys & Games | Toy Figures & Playsets | Action Figures",
"image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}
def generate_description(sample, model, processor):
# Convert sample into messages and then apply the chat template
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": [
{"type": "image","image": sample["image"]},
{"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
]},
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
print(text)
# Process the image and text
image_inputs = process_vision_info(messages)
# Tokenize the text and process the images
inputs = processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
# Move the inputs to the device
inputs = inputs.to(model.device)
# Generate the output
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<turn|>")]
generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
# Trim the generation and decode the output to text
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0]
# generate the description
description = generate_description(sample, model, processor)
print("MODEL OUTPUT>> \n")
print(description)
<bos><|turn>system You are an expert product description writer for Amazon.<turn|> <|turn>user <|image|> Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image. Only return description. The description should be SEO optimized and for a better mobile search experience. <PRODUCT> Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur </PRODUCT> <CATEGORY> Toys & Games | Toy Figures & Playsets | Action Figures </CATEGORY><turn|> <|turn>model MODEL OUTPUT>> Enhance your collection with the Marvel Avengers - Avengers Assemble Ultron-Comforter Set! This soft and cuddly blanket and pillowcase feature everyone's favorite Avengers, Iron Man, and his loyal companion War Machine. Officially licensed by Marvel. Bring home the heroic team!
总结与后续步骤
本教程介绍了如何使用 TRL 和 QLoRA 微调 Gemma 模型以完成视觉任务,特别是生成商品说明。接下来,请查看以下文档:
- 了解如何使用 Gemma 模型生成文本。
- 了解如何使用 Hugging Face Transformers 微调 Gemma 以完成文本任务。
- 了解如何使用 Hugging Face Transformers 进行 完整模型微调。
- 了解如何在 Gemma 模型上执行 分布式微调和推理。
- 了解如何将 Gemma 开放模型与 Vertex AI 搭配使用。
- 了解如何使用 KerasNLP 微调 Gemma 并部署到 Vertex AI。
在 Google Colab 中运行
在 GitHub 上查看源代码