本指南将详细介绍如何使用 Hugging Face Transformers 和 TRL 对自定义文本转 SQL 数据集进行微调 Gemma。您会了解到以下内容:
- 什么是量化低秩自适应 (QLoRA)
- 设置开发环境
- 创建和准备微调数据集
- 使用 TRL 和 SFTTrainer 微调 Gemma
- 测试模型推理并生成 SQL 查询
什么是量化低秩自适应 (QLoRA)
本指南演示了如何使用量化低秩自适应 (QLoRA)。由于它可以减少计算资源需求,同时保持高性能,因此已成为高效微调 LLM 的热门方法。在 QloRA 中,预训练模型会量化为 4 位,并且权重会被冻结。然后,附加可训练的适配器层 (LoRA),并且仅训练适配器层。之后,适配器权重可以与基准模型合并,也可以保留为单独的适配器。
设置开发环境
第一步是安装 Hugging Face 库(包括 TRL)和数据集,以微调开放式模型,包括不同的 RLHF 和对齐技术。
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard
# Install Gemma release branch from Hugging Face
%pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
# Install Hugging Face libraries
%pip install --upgrade \
"datasets==3.3.2" \
"accelerate==1.4.0" \
"evaluate==0.4.3" \
"bitsandbytes==0.45.3" \
"trl==0.15.2" \
"peft==0.14.0" \
protobuf \
sentencepiece
# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn
注意:如果您使用的是采用 Ampere 架构的 GPU(例如 NVIDIA L4)或更新型号的 GPU,则可以使用 Flash 注意力机制。Flash Attention 是一种方法,可显著加快计算速度,并将内存用量从序列长度的二次方减少到线性,从而将训练速度提高多达 3 倍。如需了解详情,请参阅 FlashAttention。
在开始训练之前,您必须确保已接受 Gemma 的使用条款。您可以在 Hugging Face 上接受许可,只需点击模型页面上的“同意并访问代码库”按钮即可:http://huggingface.co/google/gemma-3-1b-pt
接受许可后,您需要有效的 Hugging Face 令牌才能访问该模型。如果您是在 Google Colab 中运行,则可以使用 Colab Secret 安全地使用 Hugging Face 令牌;否则,您可以直接在 login
方法中设置令牌。由于您会在训练期间将模型推送到 Hub,因此请确保您的令牌也具有写入权限。
from google.colab import userdata
from huggingface_hub import login
# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)
创建和准备微调数据集
在微调 LLM 时,请务必了解您的应用场景和要解决的任务。这有助于您创建数据集来微调模型。如果您尚未定义用例,不妨回过头来重新思考。
例如,本指南重点介绍以下使用场景:
- 微调自然语言转换为 SQL 模型,以便无缝集成到数据分析工具中。其目标是显著缩短生成 SQL 查询所需的时间和专业知识,让非技术用户也能从数据中提取有意义的数据洞见。
文本转 SQL 是一个复杂的任务,需要掌握大量有关数据和 SQL 语言的(内部)知识,因此非常适合用于微调 LLM。
确定微调是合适的解决方案后,您需要一个数据集来进行微调。数据集应包含一系列示例,这些示例展示了您要解决的任务。您可以通过多种方式创建此类数据集,包括:
每种方法都有各自的优点和缺点,具体取决于预算、时间和质量要求。例如,使用现有数据集是最简单的方法,但可能无法量身定制您的具体用例;而使用领域专家可能最准确,但可能既费时又费钱。您还可以结合使用多种方法来创建指令数据集,如 Orca:从 GPT-4 的复杂说明轨迹中进行渐进学习中所示。
本指南使用现有数据集 (philschmid/gretel-synthetic-text-to-sql),这是一个高质量的合成文本到 SQL 数据集,其中包含自然语言说明、架构定义、推理和相应的 SQL 查询。
Hugging Face TRL 支持对话数据集格式的自动模板化。这意味着,您只需将数据集转换为正确的 JSON 对象,trl
便会负责模板化并将其转换为正确的格式。
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
philschmid/gretel-synthetic-text-to-sql 包含超过 10 万个示例。为了使指南大小不变,我们将其下采样为仅使用 10,000 个样本。
现在,您可以使用 Hugging Face Datasets 库加载数据集,并创建提示模板来组合自然语言指令、架构定义,并为您的助理添加系统消息。
from datasets import load_dataset
# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""
# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.
<SCHEMA>
{context}
</SCHEMA>
<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
return {
"messages": [
# {"role": "system", "content": system_message},
{"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
{"role": "assistant", "content": sample["sql"]}
]
}
# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))
# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)
# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])
使用 TRL 和 SFTTrainer 微调 Gemma
现在,您可以对模型进行微调了。借助 Hugging Face TRL SFTTrainer,您可以轻松监督微调开放式 LLM。SFTTrainer
是 transformers
库中的 Trainer
的子类,支持所有相同的功能(包括日志记录、评估和检查点),但还添加了其他实用功能,包括:
- 数据集格式设置,包括对话格式和指令格式
- 仅根据完成情况进行训练,忽略提示
- 打包数据集以提高训练效率
- 支持参数高效微调 (PEFT),包括 QloRA
- 准备模型和分词器以进行对话式微调(例如添加特殊标记)
以下代码会从 Hugging Face 加载 Gemma 模型和分词器,并初始化量化配置。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`
# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
model_class = AutoModelForCausalLM
else:
model_class = AutoModelForImageTextToText
# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float16
# Define model init arguments
model_kwargs = dict(
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
device_map="auto", # Let torch decide how to load the model
)
# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)
# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template
SFTTrainer
支持与 peft
进行原生集成,这样您就可以轻松使用 QLoRA 高效地调优 LLM。您只需创建一个 LoraConfig
并将其提供给培训师即可。
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)
在开始训练之前,您需要定义要在 SFTConfig
实例中使用的超参数。
from trl import SFTConfig
args = SFTConfig(
output_dir="gemma-text-to-sql", # directory to save and repository id
max_seq_length=512, # max sequence length for model and packing of the dataset
packing=True, # Groups multiple samples in the dataset into a single sequence
num_train_epochs=3, # number of training epochs
per_device_train_batch_size=1, # batch size per device during training
gradient_accumulation_steps=4, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory
optim="adamw_torch_fused", # use fused adamw optimizer
logging_steps=10, # log every 10 steps
save_strategy="epoch", # save checkpoint every epoch
learning_rate=2e-4, # learning rate, based on QLoRA paper
fp16=True if torch_dtype == torch.float16 else False, # use float16 precision
bf16=True if torch_dtype == torch.bfloat16 else False, # use bfloat16 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to hub
report_to="tensorboard", # report metrics to tensorboard
dataset_kwargs={
"add_special_tokens": False, # We template with special tokens
"append_concat_token": True, # Add EOS token as separator token between examples
}
)
现在,您已拥有创建 SFTTrainer
所需的所有构建块,可以开始训练模型了。
from trl import SFTTrainer
# Create Trainer object
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset["train"],
peft_config=peft_config,
processing_class=tokenizer
)
通过调用 train()
方法开始训练。
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
# Save the final model again to the Hugging Face Hub
trainer.save_model()
在测试模型之前,请务必释放内存。
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
使用 QLoRA 时,您只需训练适配器,而无需训练完整模型。这意味着,在训练期间保存模型时,您只会保存适配器权重,而不会保存完整模型。如果您想保存完整模型,以便更轻松地与 vLLM 或 TGI 等服务堆栈搭配使用,可以使用 merge_and_unload
方法将适配器权重合并到模型权重,然后使用 save_pretrained
方法保存模型。这会保存一个默认模型,该模型可用于推理。
from peft import PeftModel
# Load Model base model
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
测试模型推理并生成 SQL 查询
训练完成后,您需要评估和测试模型。您可以从测试数据集中加载不同的样本,并针对这些样本评估模型。
import torch
from transformers import pipeline
model_id = "gemma-text-to-sql"
# Load Model with PEFT adapter
model = model_class.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch_dtype,
attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
我们从测试数据集中加载一个随机样本,并生成一个 SQL 命令。
from random import randint
import re
# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]
# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)
# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)
# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
总结和后续步骤
本教程介绍了如何使用 TRL 和 QLoRA 微调 Gemma 模型。接下来,请参阅以下文档: