Fine-Tune Gemma for Vision Tasks using Hugging Face Transformers and QLoRA


This guide walks you through how to fine-tune Gemma on a custom image and text dataset for a vision task (generating product descriptions) using Hugging Face Transformers and TRL. You will learn:

  • What is Quantized Low-Rank Adaptation (QLoRA)
  • Setup development environment
  • Create and prepare the fine-tuning dataset
  • Fine-tune Gemma using TRL and the SFTTrainer
  • Test Model Inference and generate product descriptions from images and text.

What is Quantized Low-Rank Adaptation (QLoRA)

This guide demonstrates the use of Quantized Low-Rank Adaptation (QLoRA), a popular method for fine-tuning LLMs efficiently: it reduces computational resource requirements while maintaining high performance. In QLoRA, the pretrained model is quantized to 4-bit and its weights are frozen. Trainable adapter layers (LoRA) are then attached, and only those adapter layers are trained. Afterwards, the adapter weights can be merged with the base model or kept as a separate adapter.
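
To make the quantization step more concrete, the following is a minimal sketch of storing weights as 4-bit integer codes plus a scale. It is a simplification: it uses plain uniform absmax quantization rather than the NF4 data type that QLoRA uses, and the weight values and the function names `quantize_4bit` and `dequantize_4bit` are made up for illustration.

```python
# Simplified illustration only: uniform absmax quantization to 4-bit integers.
# QLoRA itself uses the NF4 data type with block-wise scaling; this sketch
# only shows the idea of storing small integer codes plus one float scale.

def quantize_4bit(weights):
    """Map floats to integer codes in [-7, 7] plus a single float scale."""
    absmax = max((abs(w) for w in weights), default=0.0)
    scale = absmax / 7 if absmax > 0 else 1.0
    codes = [max(-7, min(7, round(w / scale))) for w in weights]
    return codes, scale


def dequantize_4bit(codes, scale):
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


weights = [0.70, -0.35, 0.10, 0.02, -0.61]  # hypothetical frozen weights
codes, scale = quantize_4bit(weights)
restored = dequantize_4bit(codes, scale)

# The round-trip error of each weight is at most half a quantization step.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, scale, max_error)
```

The base weights stay in this compact form and are never updated; only the small LoRA adapter matrices receive gradients.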

Setup development environment

The first step is to install the Hugging Face libraries, including TRL and Datasets, to fine-tune an open model.

# Install PyTorch & other libraries
%pip install torch tensorboard torchvision

# Install Transformers
%pip install transformers

# Install Hugging Face libraries
%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf pillow sentencepiece

# Uncomment if you are running on a GPU that supports the BF16 data type and Flash Attention, such as NVIDIA L4 or NVIDIA A100
#%pip install flash-attn

Note: If you are using a GPU with Ampere architecture or newer (such as NVIDIA A100 or L4), you can use Flash Attention. Flash Attention is a method that significantly speeds up computation and reduces memory usage from quadratic to linear in sequence length, accelerating training by up to 3x. Learn more at FlashAttention.

You need a valid Hugging Face token to publish your model. If you are running inside Google Colab, you can securely store your Hugging Face token in Colab secrets; otherwise, you can pass the token directly to the login method. Make sure your token has write access, because the model is pushed to the Hub during training.

# Log in to the Hugging Face Hub
from huggingface_hub import login
login()
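
One possible way to avoid typing the token interactively is to look it up before calling login. The sketch below is an assumption-laden example, not part of the original notebook: the helper name `resolve_hf_token` and the secret name `HF_TOKEN` are illustrative, and the Colab branch only applies when the code runs inside Google Colab with a secret of that name.

```python
import os


def resolve_hf_token(env_var="HF_TOKEN"):
    """Return a token from the environment, then from Colab secrets, else None."""
    token = os.environ.get(env_var)
    if token:
        return token
    try:
        # This module is only importable inside Google Colab.
        from google.colab import userdata
    except ImportError:
        return None
    try:
        return userdata.get(env_var)
    except Exception:
        # The secret is missing or notebook access was not granted.
        return None


token = resolve_hf_token()

# With huggingface_hub installed, the token can then be passed explicitly:
# from huggingface_hub import login
# login(token=token)
```

If no token is found, `login()` without arguments still falls back to the interactive prompt.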

Create and prepare the fine-tuning dataset

When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.

As an example, this guide focuses on the following use case:

  • Fine-tuning a Gemma model to generate concise, SEO-optimized product descriptions for an ecommerce platform, specifically tailored for mobile search.

This guide uses the philschmid/amazon-product-descriptions-vlm dataset, a dataset of Amazon product descriptions, including product images and categories.

Hugging Face TRL supports multimodal conversations. The important piece is the "image" content type, which tells the processing class that it should load the image. The structure should follow:

{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
{"messages": [{"role": "system", "content": [{"type": "text", "text":"You are..."}]}, {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]}
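
Before training, it can help to check that each conversation follows this structure. The following is a small sketch of such a check; the helper name `count_images` and the example record are hypothetical, and the set of accepted roles and content types is an assumption based on the structure shown above.

```python
def count_images(messages):
    """Validate role/content fields and count image parts in one conversation."""
    allowed_roles = {"system", "user", "assistant"}
    allowed_types = {"text", "image"}
    n_images = 0
    for message in messages:
        role = message.get("role")
        if role not in allowed_roles:
            raise ValueError(f"unexpected role: {role!r}")
        content = message.get("content")
        if isinstance(content, str):
            continue  # plain-string content carries no image
        if not isinstance(content, list):
            raise ValueError("content must be a string or a list of parts")
        for part in content:
            part_type = part.get("type")
            if part_type not in allowed_types:
                raise ValueError(f"unexpected content type: {part_type!r}")
            if part_type == "image":
                n_images += 1
    return n_images


# Hypothetical record in the same shape as the structure above.
example = {
    "messages": [
        {"role": "system", "content": [{"type": "text", "text": "You are..."}]},
        {"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "..."}]},
    ]
}
print(count_images(example["messages"]))  # 1
```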

You can now use the Hugging Face Datasets library to load the dataset and create a prompt template that combines the image, product name, and category, and adds a system message. The dataset includes images as PIL.Image objects.

from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are an expert product description writer for Amazon."

# User prompt that combines the product name and category
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
{product}
</PRODUCT>

<CATEGORY>
{category}
</CATEGORY>
"""

# Convert dataset to OAI messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                #"content": [{"type": "text", "text": system_message}],
                "content": system_message,
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            product=sample["Product Name"],
                            category=sample["Category"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["description"]}],
            },
        ],
    }

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs

# Load dataset from the hub
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")
dataset = dataset.train_test_split(test_size=0.1)

# Convert dataset to OAI messages
# Use a list comprehension to keep the PIL.Image type; .map converts images to bytes
dataset_train = [format_data(sample) for sample in dataset["train"]]
dataset_test = [format_data(sample) for sample in dataset["test"]]

print(dataset_train[345]["messages"])
[{'role': 'system', 'content': 'You are an expert product description writer for Amazon.'}, {'role': 'user', 'content': [{'type': 'text', 'text': "Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.\nOnly return description. The description should be SEO optimized and for a better mobile search experience.\n\n<PRODUCT>\nRazor Agitator BMX/Freestyle Bike, 20-Inch\n</PRODUCT>\n\n<CATEGORY>\nSports & Outdoors | Outdoor Recreation | Cycling | Kids' Bikes & Accessories | Kids' Bikes\n</CATEGORY>\n"}, {'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x413 at 0x7B7250181790>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Conquer the streets with the Razor Agitator BMX Bike! This 20-inch freestyle bike is built for young riders ready to take on any challenge. Durable frame, responsive handling – perfect for tricks and cruising.  Get yours today!'}]}]
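
Because the split above is random, the sample at a given index can differ between runs. `Dataset.train_test_split` accepts a `seed` argument for reproducibility. The idea can be sketched with the standard library alone; the helper `split_indices` below is hypothetical and only illustrates a seeded shuffle-and-cut, not the exact algorithm the Datasets library uses.

```python
import math
import random


def split_indices(n_samples, test_size=0.1, seed=42):
    """Shuffle indices with a fixed seed and cut off a test portion."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(n_samples * test_size)
    return indices[n_test:], indices[:n_test]


# 1345 is the number of examples reported when the dataset is generated.
train_idx, test_idx = split_indices(1345)
print(len(train_idx), len(test_idx))  # 1210 135

# The same seed always yields the same split.
print(split_indices(1345) == (train_idx, test_idx))  # True
```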

Fine-tune Gemma using TRL and the SFTTrainer

You are now ready to fine-tune your model. The Hugging Face TRL SFTTrainer makes it straightforward to run supervised fine-tuning on open LLMs. The SFTTrainer is a subclass of the Trainer from the transformers library and supports all the same features, including logging, evaluation, and checkpointing, and adds quality-of-life features, including:

  • Dataset formatting, including conversational and instruction formats
  • Training on completions only, ignoring prompts
  • Packing datasets for more efficient training
  • Parameter-efficient fine-tuning (PEFT) support including QLoRA
  • Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)

The following code loads the Gemma model and processor from Hugging Face and initializes the quantization configuration.

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-4-E2B" # @param ["google/gemma-4-E2B","google/gemma-4-E4B"] {"allow-input":true}

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["dtype"],
    bnb_4bit_quant_storage=model_kwargs["dtype"],
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it") # Load the Instruction Tokenizer to use the official Gemma template
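
To get a feel for why 4-bit loading matters, here is a back-of-the-envelope estimate of weight memory at different precisions. The parameter count is a hypothetical round number, not an official figure for any Gemma model, and the estimate ignores quantization constants, modules that stay unquantized, activations, and optimizer state.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate memory for the weights alone, in gigabytes (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


n_params = 5_000_000_000  # hypothetical parameter count, for illustration only

for label, bits in [("bf16", 16), ("4-bit", 4)]:
    print(f"{label}: {weight_memory_gb(n_params, bits):.1f} GB")
# bf16: 10.0 GB
# 4-bit: 2.5 GB
```

Actual GPU usage during training is higher than the 4-bit figure, since gradients and optimizer state for the adapters and the activations also need memory.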

The SFTTrainer supports a built-in integration with peft, which makes it straightforward to efficiently tune LLMs using QLoRA. You only need to create a LoraConfig and provide it to the trainer.

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"], # make sure to save the lm_head and embed_tokens as you train the special tokens
    ensure_weight_tying=True,
)
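
As a rough illustration of what `r` and `lora_alpha` control, the sketch below counts the parameters in the two low-rank factors for a single linear layer and computes the default LoRA scaling factor `lora_alpha / r`. The layer shape is a made-up example and does not correspond to a specific Gemma layer.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters in the low-rank factors A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


d_in, d_out = 2048, 2048  # hypothetical layer shape, for illustration only
r, lora_alpha = 16, 16    # values used in the LoraConfig

full = d_in * d_out                         # parameters in the frozen weight
adapter = lora_param_count(d_in, d_out, r)  # trainable adapter parameters
scaling = lora_alpha / r                    # factor applied to the adapter output

print(full, adapter, scaling)  # 4194304 65536 1.0
```

For this hypothetical layer, the adapter holds well under 2% of the parameters of the frozen weight, which is what keeps the number of trainable parameters small.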

Before you can start training, you need to define the hyperparameters you want to use in an SFTConfig and a custom collate_fn to handle the vision processing. The collate_fn converts the messages with text and images into a format that the model can understand.

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-product-description",     # directory to save and repository id
    num_train_epochs=3,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    eval_strategy="epoch",                      # evaluate checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    push_to_hub=True,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
    remove_unused_columns = False               # important for collator
)

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100

    batch["labels"] = labels
    return batch
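
The label masking in the collator can be illustrated without tensors. The following sketch uses plain Python lists and made-up token ids; the helper name `mask_labels` is hypothetical. The value -100 is the default index that PyTorch's cross-entropy loss ignores.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default


def mask_labels(input_ids, ignored_token_ids):
    """Copy input_ids and replace every ignored token id with IGNORE_INDEX."""
    return [
        [IGNORE_INDEX if tok in ignored_token_ids else tok for tok in row]
        for row in input_ids
    ]


# Hypothetical ids: 0 = padding, 9 = image placeholder token.
input_ids = [
    [2, 9, 9, 15, 27, 1],
    [2, 15, 1, 0, 0, 0],
]
labels = mask_labels(input_ids, ignored_token_ids={0, 9})
print(labels)
# [[2, -100, -100, 15, 27, 1], [2, 15, 1, -100, -100, -100]]
```

Positions set to -100 contribute nothing to the loss, so the model is not trained to predict padding or image placeholder tokens.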

You now have every building block you need to create your SFTTrainer to start the training of your model.

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

Start training by calling the train() method.

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.
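
To estimate how long training will run, you can approximate the number of optimizer steps from the dataset size and the batch settings. This is a rough sketch: the helper `total_optimizer_steps` is hypothetical, the training-set size assumes about 90% of the 1,345 examples, and the exact count reported by the trainer may differ slightly when gradient accumulation is used.

```python
import math


def total_optimizer_steps(n_examples, batch_size, epochs, grad_accum=1, n_devices=1):
    """Approximate optimizer steps for a full training run."""
    examples_per_step = batch_size * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / examples_per_step)
    return steps_per_epoch * epochs


n_train = 1210  # assumption: about 90% of the 1,345 examples
steps = total_optimizer_steps(n_train, batch_size=1, epochs=3)
print(steps)       # 3630
print(steps // 5)  # 726 log entries with logging_steps=5
```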

Before you can test your model, make sure to free the memory.

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

When using QLoRA, you train only the adapters, not the full model. As a result, checkpoints saved during training contain only the adapter weights. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the merge_and_unload method and then save the model with the save_pretrained method. This produces a standalone model that can be used for inference without PEFT.

from peft import PeftModel

# Load base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
['merged_model/processor_config.json']
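
Conceptually, merging folds the low-rank update into the base weight, so that a single matrix gives the same output as the base weight plus the adapter. The sketch below checks this on tiny hand-written matrices in pure Python. The shapes, values, and helper names are made up for illustration; it demonstrates the arithmetic and is not a description of how PEFT implements merge_and_unload internally.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_scaled(a, b, scale=1.0):
    """Return a + scale * b, element-wise."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


# Hypothetical tiny shapes: W is 2x3, B is 2x1, A is 1x3 (rank r = 1).
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
B = [[0.5], [-1.0]]
A = [[1.0, 2.0, 0.0]]
lora_alpha, r = 2, 1
scaling = lora_alpha / r

x = [[1.0], [1.0], [1.0]]  # input as a column vector

# Adapter kept separate: y = W x + scaling * B (A x)
y_adapter = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scale=scaling)

# Adapter merged: W_merged = W + scaling * B A, then y = W_merged x
W_merged = add_scaled(W, matmul(B, A), scale=scaling)
y_merged = matmul(W_merged, x)

print(y_adapter)  # [[6.0], [-6.0]]
print(y_merged)   # [[6.0], [-6.0]]
```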

Test Model Inference and generate product descriptions

After the training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.

model_id = "merged_model"

# Load the merged model
model = AutoModelForImageTextToText.from_pretrained(
  model_id,
  device_map="auto",
  dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.

You can test inference by providing a product name, category, and image. The sample includes a Marvel action figure.

import requests
from PIL import Image

# Test sample with Product Name, Category and Image
sample = {
  "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
  "category": "Toys & Games | Toy Figures & Playsets | Action Figures",
  "image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}

def generate_description(sample, model, processor):
    # Convert sample into messages and then apply the chat template
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": [
            {"type": "image","image": sample["image"]},
            {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
        ]},
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(text)
    # Process the image and text
    image_inputs = process_vision_info(messages)
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)

    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<turn|>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

# generate the description
description = generate_description(sample, model, processor)
print("MODEL OUTPUT>> \n")
print(description)
<bos><|turn>system
You are an expert product description writer for Amazon.<turn|>
<|turn>user


<|image|>

Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur
</PRODUCT>

<CATEGORY>
Toys & Games | Toy Figures & Playsets | Action Figures
</CATEGORY><turn|>
<|turn>model

MODEL OUTPUT>> 

Enhance your collection with the Marvel Avengers - Avengers Assemble Ultron-Comforter Set! This soft and cuddly blanket and pillowcase feature everyone's favorite Avengers, Iron Man, and his loyal companion War Machine. Officially licensed by Marvel.  Bring home the heroic team!
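
The trimming step in generate_description relies on the generated sequence starting with the prompt tokens. The following sketch shows that slicing on plain lists, plus an optional cut at the first stop token. The token ids and the helper names `trim_prompt` and `cut_at_stop` are made up for illustration.

```python
def trim_prompt(input_ids, generated_ids):
    """Drop the prompt prefix from each generated sequence."""
    return [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]


def cut_at_stop(tokens, stop_token_ids):
    """Keep tokens up to, but not including, the first stop token."""
    for i, tok in enumerate(tokens):
        if tok in stop_token_ids:
            return tokens[:i]
    return tokens


# Hypothetical ids: the prompt has 4 tokens and 1 is the stop token.
prompt = [[2, 11, 12, 13]]
generated = [[2, 11, 12, 13, 40, 41, 42, 1]]

new_tokens = trim_prompt(prompt, generated)
print(new_tokens)                       # [[40, 41, 42, 1]]
print(cut_at_stop(new_tokens[0], {1}))  # [40, 41, 42]
```

In the guide's code, decoding with skip_special_tokens=True serves a similar purpose to the second helper by dropping special tokens from the decoded text.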

Summary and next steps

This tutorial covered how to fine-tune a Gemma model for vision tasks using TRL and QLoRA, specifically for generating product descriptions. Check out the following docs next: