使用 Hugging Face Transformers 和 QloRA 微調 Gemma

本指南將逐步說明如何使用 Hugging Face TransformersTRL,針對自訂文字轉 SQL 資料集微調 Gemma。您將學會:

  • 量化低秩調整 (QLoRA) 是什麼
  • 設定開發環境
  • 建立及準備精修資料集
  • 使用 TRL 和 SFTTrainer 微調 Gemma
  • 測試模型推論並產生 SQL 查詢

量化低秩調整 (QLoRA) 是什麼

本指南將說明如何使用量化低秩序調整 (QLoRA),這是一種有效精細調整 LLM 的熱門方法,因為它可減少運算資源需求,同時維持高效能。在 QloRA 中,預先訓練的模型會量化為 4 位元,權重則會凍結。接著,系統會附加可訓練的轉接層 (LoRA),並只訓練轉接層。之後,轉接器權重可與基礎模型合併,或保留為獨立的轉接器。

設定開發環境

第一步是安裝 Hugging Face 程式庫 (包括 TRL) 和資料集,以便微調開放式模型,包括不同的 RLHF 和對齊技術。

# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Gemma release branch from Hugging Face
%pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  protobuf \
  sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn

注意:如果您使用的是搭載 Ampere 架構 (例如 NVIDIA L4) 或更新版本的 GPU,可以使用 Flash attention。Flash Attention 是一種方法,可大幅加快運算速度,並將序列長度從二次方減少到線性,進而將訓練速度加快至 3 倍。詳情請參閱 FlashAttention

請務必先接受 Gemma 的使用條款,才能開始訓練。您可以接受 Hugging Face 的授權,方法是點選模型頁面上的「同意並存取存放區」按鈕:http://huggingface.co/google/gemma-3-1b-pt

接受授權後,您必須使用有效的 Hugging Face 權杖才能存取模型。如果您在 Google Colab 中執行,可以使用 Colab 機密資料安全地使用 Hugging Face 權杖,否則您可以直接在 login 方法中設定權杖。請確認您的權杖也有寫入權限,因為您會在訓練期間將模型推送至 Hub。

from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab 
login(hf_token)

建立及準備精修資料集

在微調 LLM 時,請務必瞭解您的用途和要解決的任務。這有助於您建立資料集,以便微調模型。如果您尚未定義用途,建議您重新規劃。

本指南會以以下用途為例:

  • 微調將自然語言轉為 SQL 的模型,以便順暢整合至資料分析工具。這項工具的目標是大幅減少產生 SQL 查詢所需的時間和專業知識,讓非技術人員也能從資料中擷取有意義的洞察資料。

文字轉 SQL 是微調 LLM 的絕佳用途,因為這項複雜工作需要大量的資料和 SQL 語言 (內部) 知識。

確定微調是適當的解決方案後,您需要資料集才能進行微調。資料集應包含多種示例,展示您要解決的工作。建立這類資料集的方式有很多種,包括:

  • 使用現有的開放原始碼資料集,例如 Spider
  • 使用大型語言模型 (例如 Alpaca) 建立的合成資料集
  • 使用人類建立的資料集,例如 Dolly
  • 使用多種方法 (例如 Orca) 進行組合

每種方法各有優缺,取決於預算、時間和品質要求。舉例來說,使用現有資料集最簡單,但可能無法針對您的特定用途進行調整;而聘請領域專家可能最準確,但可能耗時且費用高昂。您也可以結合多種方法來建立指令資料集,如 Orca:透過 GPT-4 的複雜說明追蹤記錄進行漸進式學習 一文所述。

本指南使用現有資料集 (philschmid/gretel-synthetic-text-to-sql),這是高品質的模擬文字轉 SQL 資料集,包含自然語言指示、結構定義、推理和對應的 SQL 查詢。

Hugging Face TRL 支援對話資料集格式的自動範本建立功能。也就是說,您只需將資料集轉換為正確的 JSON 物件,trl 就會負責建立範本並將其轉換為正確的格式。

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

philschmid/gretel-synthetic-text-to-sql 包含超過 10 萬個樣本。為了讓指南保持小巧,我們將其降樣至只使用 10,000 個樣本。

你現在可以使用 Hugging Face Datasets 程式庫載入資料集,並建立提示範本,結合自然語言指示、結構定義,以及為助理新增系統訊息。

from datasets import load_dataset

# System message for the assistant 
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }  

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])

使用 TRL 和 SFTTrainer 對 Gemma 進行微調

您現在可以微調模型了。只要使用 Hugging Face TRL 的 SFTTrainer,就能輕鬆監督開放式 LLM 的微調作業。SFTTrainertransformers 程式庫中 Trainer 的子類別,支援所有相同的功能,包括記錄、評估和檢查點,但新增了其他便利功能,包括:

  • 資料集格式,包括對話和指示格式
  • 只訓練完成動作,忽略提示
  • 壓縮資料集以提高訓練效率
  • 高效參數微調 (PEFT) 支援功能,包括 QLoRA
  • 準備對話微調的模型和代碼化工具 (例如新增特殊符記)

下列程式碼會從 Hugging Face 載入 Gemma 模型和分析器,並初始化量化設定。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

SFTTrainer 支援與 peft 的原生整合,讓您能輕鬆使用 QLoRA 有效調整 LLM。您只需建立 LoraConfig 並提供給訓練工具即可。

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

您必須先定義要在 SFTConfig 執行個體中使用的超參數,才能開始訓練。

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_seq_length=512,                     # max sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

您現在已擁有建立 SFTTrainer 所需的所有建構區塊,可以開始訓練模型了。

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)

呼叫 train() 方法開始訓練。

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

請務必釋放記憶體,才能測試模型。

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

使用 QLoRA 時,您只需訓練轉接器,而非完整模型。也就是說,在訓練期間儲存模型時,您只會儲存適應器權重,而非完整模型。如果您想儲存完整模型,以便搭配 vLLM 或 TGI 等服務堆疊使用,可以使用 merge_and_unload 方法將轉接器權重合併至模型權重,然後使用 save_pretrained 方法儲存模型。這會儲存可用於推論的預設模型。

from peft import PeftModel

# Load Model base model
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

測試模型推論並產生 SQL 查詢

訓練完成後,您需要評估及測試模型。您可以從測試資料集中載入不同的樣本,並針對這些樣本評估模型。

import torch
from transformers import pipeline

model_id = "gemma-text-to-sql"

# Load Model with PEFT adapter
model = model_class.from_pretrained(
  model_id,
  device_map="auto",
  torch_dtype=torch_dtype,
  attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

我們來從測試資料集中載入隨機樣本,並產生 SQL 指令。

from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)

# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

總結與後續步驟

本教學課程說明如何使用 TRL 和 QLoRA 微調 Gemma 模型。接著請查看下列文件: