Fine-tune Gemma using Hugging Face Transformers and QLoRA


This guide walks you through how to fine-tune Gemma on a custom text-to-SQL dataset using Hugging Face Transformers and TRL. You will learn:

  • What Quantized Low-Rank Adaptation (QLoRA) is
  • How to set up the development environment
  • How to create and prepare the fine-tuning dataset
  • How to fine-tune Gemma using TRL and the SFTTrainer
  • How to test model inference and generate SQL queries

What is Quantized Low-Rank Adaptation (QLoRA)

This guide demonstrates the use of Quantized Low-Rank Adaptation (QLoRA). QLoRA is a popular method for fine-tuning LLMs efficiently, reducing computational resource requirements while maintaining high performance. In QLoRA, the pretrained model is quantized to 4 bits and its weights are frozen. Trainable adapter layers (LoRA) are then attached, and only those adapter layers are trained. Afterwards, the adapter weights can either be merged with the base model or kept as a separate adapter.
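To make the mechanics concrete, the following minimal sketch (plain PyTorch, not the peft implementation used later in this guide) shows how a LoRA adapter adds a trainable low-rank update on top of a frozen linear layer:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: y = W0(x) + (alpha / r) * B(A(x)), with W0 frozen."""
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the pretrained weight
        self.lora_a = nn.Linear(base.in_features, r, bias=False)   # down-projection A
        self.lora_b = nn.Linear(r, base.out_features, bias=False)  # up-projection B
        nn.init.zeros_(self.lora_b.weight)       # start as a no-op, as LoRA does
        self.scaling = alpha / r

    def forward(self, x):
        # Only lora_a and lora_b receive gradients; the base layer stays fixed
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))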

Set up the development environment

The first step is to install the Hugging Face libraries, including TRL, and datasets to fine-tune the open model, including different RLHF and alignment techniques.

# Install Pytorch & other libraries
%pip install torch tensorboard

# Install Transformers
%pip install transformers

# Install Hugging Face libraries
%pip install datasets accelerate evaluate bitsandbytes trl peft protobuf sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#%pip install flash-attn

Note: If you have a GPU with Ampere architecture (such as the NVIDIA L4) or newer, you can use Flash Attention. Flash Attention is a method that significantly speeds up computation and reduces memory usage from quadratic to linear in the sequence length, speeding up training by up to 3x. Learn more at FlashAttention.
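If your GPU qualifies, you can request the kernel when loading the model. A hedged sketch (attn_implementation is a standard transformers argument; the model id matches the one used later in this guide):

import torch
from transformers import AutoModelForImageTextToText

# Only request Flash Attention on Ampere (compute capability 8.x) or newer GPUs
use_flash = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-4-E2B",
    dtype=torch.bfloat16 if use_flash else torch.float16,
    attn_implementation="flash_attention_2" if use_flash else "eager",
)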

You need a valid Hugging Face token to publish your model. If you are running in Google Colab, you can use Colab secrets to use your Hugging Face token securely; otherwise, you can set the token directly in the login method. Make sure your token also has write access, as you push your model to the Hub during training.

# Login into Hugging Face Hub
from huggingface_hub import login
login()
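In Colab, a common pattern is to read the token from a Colab secret instead of typing it interactively (this sketch assumes you stored your token under the secret name HF_TOKEN):

# Read the Hugging Face token from a Colab secret
from google.colab import userdata
from huggingface_hub import login

login(token=userdata.get("HF_TOKEN"))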

Create and prepare the fine-tuning dataset

When fine-tuning LLMs, it is important to know your use case and the task you want to solve. This helps you create a dataset to fine-tune your model. If you haven't defined your use case yet, you might want to go back to the drawing board.

As an example, this guide focuses on the following use case:

  • Fine-tune a natural-language-to-SQL model for seamless integration into a data analysis tool. The goal is to significantly reduce the time and expertise required to generate SQL queries, enabling even non-technical users to extract meaningful insights from data.

Text-to-SQL can be a good use case for fine-tuning LLMs, as it is a complex task that requires a lot of (internal) knowledge about the data and the SQL language.

Once you have determined that fine-tuning is the right solution, you need a dataset to fine-tune on. The dataset should contain a diverse set of demonstrations of the task you want to solve. There are several ways to create such a dataset, including:

  • Using existing open-source datasets, such as Spider
  • Using synthetic datasets created by LLMs, such as Alpaca
  • Using datasets created by humans, such as Dolly
  • Using a combination of methods, such as Orca

Each of these methods has its own advantages and disadvantages, depending on budget, time, and quality requirements. For example, using an existing dataset is the easiest but might not be tailored to your specific use case, while using domain experts might be the most accurate but can be time-consuming and expensive. It is also possible to combine several methods to create an instruction dataset, as shown in Orca: Progressive Learning from Complex Explanation Traces of GPT-4.

This guide uses an existing dataset (philschmid/gretel-synthetic-text-to-sql), a high-quality synthetic text-to-SQL dataset including natural language instructions, schema definitions, reasoning, and the corresponding SQL query.

Hugging Face TRL supports automatic templating of conversational dataset formats. This means you only need to convert your dataset into the right JSON objects, and trl takes care of templating and putting them into the right format.

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

philschmid/gretel-synthetic-text-to-sql contains over 100k samples. To keep this guide manageable, the dataset is downsampled to 12,500 samples, which the code below splits into 10,000 training and 2,500 test samples.

You can now use the Hugging Face Datasets library to load the dataset and create a prompt template that combines the natural language instruction and the schema definition, and add a system message for your assistant.

from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)
# split dataset into 80% training samples and 20% test samples
dataset = dataset.train_test_split(test_size=0.2)

# Print formatted user prompt
for item in dataset["train"][0]["messages"]:
  print(item)
{'content': 'You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.', 'role': 'system'}
{'content': "Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\nCREATE TABLE Menu (id INT PRIMARY KEY, name VARCHAR(255), category VARCHAR(255), price DECIMAL(5,2));\n</SCHEMA>\n\n<USER_QUERY>\nCalculate the average price of all menu items in the Vegan category\n</USER_QUERY>\n", 'role': 'user'}
{'content': "SELECT AVG(price) FROM Menu WHERE category = 'Vegan';", 'role': 'assistant'}

Fine-tune Gemma using TRL and the SFTTrainer

You are now ready to fine-tune your model. The Hugging Face TRL SFTTrainer makes supervised fine-tuning of open LLMs straightforward. The SFTTrainer is a subclass of the Trainer from the transformers library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:

  • Dataset formatting, including conversational and instruction formats
  • Training on completions only, ignoring the prompts
  • Packing datasets for more efficient training (see the sketch after this list)
  • Parameter-efficient fine-tuning (PEFT) support, including QLoRA
  • Preparing the model and tokenizer for conversational fine-tuning (such as adding special tokens)
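As an example of how little wiring these features need, packing is a single flag on the SFTConfig used later in this guide; a minimal sketch with illustrative values (the output directory here is hypothetical):

from trl import SFTConfig

# Packing concatenates several short examples into one max_length sequence,
# which reduces padding waste and speeds up training
packed_args = SFTConfig(
    output_dir="gemma-text-to-sql-packed",  # hypothetical output directory
    max_length=512,
    packing=True,
)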

The following code loads the Gemma model and the tokenizer from Hugging Face and initializes the quantization configuration.

import torch
from transformers import AutoTokenizer, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-4-E2B" # @param ["google/gemma-4-E2B","google/gemma-4-E4B"] {"allow-input":true}

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    dtype=torch_dtype,
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['dtype'],
    bnb_4bit_quant_storage=model_kwargs['dtype'],
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-E2B-it") # Load the Instruction Tokenizer to use the official Gemma template

The SFTTrainer supports a built-in integration with peft, which makes it straightforward to efficiently tune LLMs using QLoRA. You only need to create a LoraConfig and provide it to the trainer.

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"], # make sure to save the lm_head and embed_tokens as you train the special tokens
    ensure_weight_tying=True,
)
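If you want to see how small the trainable footprint is before committing to a run, you can apply the config yourself with standard peft APIs. A throwaway inspection sketch; note that wrapping modifies the model in place, so run it in a scratch session rather than before the actual training below, where the unwrapped model is handed to the SFTTrainer together with peft_config:

from peft import get_peft_model

# Wrap the base model with the LoRA adapters and report parameter counts
preview = get_peft_model(model, peft_config)
preview.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...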

Before you can start your training, you need to define the hyperparameters you want to use in an SFTConfig instance.

import torch
from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_length=512,                         # max length for model and packing of the dataset
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    eval_strategy="epoch",                  # evaluate checkpoint every epoch
    learning_rate=5e-5,                     # learning rate
    fp16=True if model.dtype == torch.float16 else False,   # use float16 precision
    bf16=True if model.dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # Template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

You now have every building block you need to create your SFTTrainer and start training your model.

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    processing_class=tokenizer,
)

Start training by calling the train() method.

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.

Before you can test your model, make sure to free the memory.

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

When using QLoRA, you only train adapters and not the full model. This means that when saving the model during training, you only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with serving stacks like vLLM or TGI, you can merge the adapter weights into the model weights using the merge_and_unload method and then save the model with the save_pretrained method. This saves a default model that can be used for inference.

from peft import PeftModel

# Load the base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
('merged_model/tokenizer_config.json',
 'merged_model/chat_template.jinja',
 'merged_model/tokenizer.json')
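If you would rather keep the adapter separate instead of merging, you can attach it to the base model at inference time; a minimal sketch reusing the ids from above (model_id is still the base model id at this point, and args.output_dir holds the trained adapter):

from peft import PeftModel
from transformers import AutoModelForImageTextToText

# Load the base model and attach the trained LoRA adapter without merging
base = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto")
adapter_model = PeftModel.from_pretrained(base, args.output_dir)
# adapter_model.generate(...) now applies the base weights plus the adapter on the fly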

Test model inference and generate SQL queries

After the training is done, you will want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.

import torch
from transformers import pipeline

model_id = "merged_model"

# Load the merged model
model = AutoModelForImageTextToText.from_pretrained(
  model_id,
  device_map="auto",
  dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
The tied weights mapping and config for this model specifies to tie model.language_model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints with different values, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning.

Let's load a random sample from the test dataset and generate a SQL command.

from random import randint
import re
from transformers import pipeline, GenerationConfig

config = GenerationConfig.from_pretrained(model_id)
config.max_new_tokens = 256

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]) - 1)  # randint is inclusive on both ends
test_sample = dataset["test"][rand_idx]

# Convert the test example into a prompt with the Gemma template
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)
print(prompt)

# Generate our SQL query.
outputs = pipe(prompt, generation_config=config)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][1]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][1]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
<bos><start_of_turn>system
You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.<end_of_turn>
<start_of_turn>user
Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));
</SCHEMA>

<USER_QUERY>
Delete a broadband plan from the 'broadband_plans' table
</USER_QUERY><end_of_turn>
<start_of_turn>model

Context:
 CREATE TABLE broadband_plans (plan_id INT, plan_name VARCHAR(255), download_speed INT, upload_speed INT, price DECIMAL(5,2));
Query:
 Delete a broadband plan from the 'broadband_plans' table
Original Answer:
DELETE FROM broadband_plans WHERE plan_id = 3001;
Generated Answer:
DELETE FROM broadband_plans
WHERE plan_name = 'Basic';
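A single sample only gives you a feel for the model's behavior. A hedged sketch of a simple exact-match evaluation over a slice of the test split, reusing the pipe, config, and dataset objects from above (exact string match is a crude proxy; a production evaluation would compare query execution results instead):

def generate_sql(sample):
    # Build the prompt with the Gemma template and generate with the pipeline above
    prompt = pipe.tokenizer.apply_chat_template(
        sample["messages"][:2], tokenize=False, add_generation_prompt=True
    )
    outputs = pipe(prompt, generation_config=config)
    return outputs[0]["generated_text"][len(prompt):].strip()

# Exact-match accuracy over a small slice of the test split
n, hits = 50, 0
for i in range(n):
    sample = dataset["test"][i]
    hits += int(generate_sql(sample) == sample["messages"][2]["content"].strip())
print(f"Exact-match accuracy on {n} samples: {hits / n:.1%}")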

Summary and next steps

This tutorial covered how to fine-tune a Gemma model using TRL and QLoRA. Next, check out the following documents: