ปรับแต่ง Gemma โดยใช้ Hugging Face Transformers และ QloRA

คู่มือนี้จะแนะนำวิธีปรับแต่ง Gemma ในชุดข้อมูลข้อความถึง SQL ที่กำหนดเองโดยใช้ Transformers และ TRL ของ Hugging Face คุณจะได้เรียนรู้:

  • Quantized Low-Rank Adaptation (QLoRA) คืออะไร
  • ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์
  • สร้างและเตรียมชุดข้อมูลการปรับแต่ง
  • ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer
  • ทดสอบการอนุมานโมเดลและสร้างการค้นหา SQL

Quantized Low-Rank Adaptation (QLoRA) คืออะไร

คู่มือนี้แสดงให้เห็นการใช้ Quantized Low-Rank Adaptation (QLoRA) ซึ่งกลายเป็นวิธีที่ได้รับความนิยมในการปรับแต่ง LLM อย่างมีประสิทธิภาพ เนื่องจากช่วยลดข้อกำหนดด้านทรัพยากรการคำนวณในขณะที่ยังคงประสิทธิภาพสูงไว้ได้ ใน QLoRA โมเดลที่ฝึกไว้ล่วงหน้าจะได้รับการควอนไทซ์เป็น 4 บิตและมีการตรึงน้ำหนัก จากนั้นจะมีการแนบเลเยอร์อะแดปเตอร์ที่ฝึกได้ (LoRA) และจะมีการฝึกเฉพาะเลเยอร์อะแดปเตอร์ หลังจากนั้น คุณจะผสานน้ำหนักของอะแดปเตอร์กับโมเดลพื้นฐานหรือเก็บไว้เป็นอะแดปเตอร์แยกต่างหากก็ได้

ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์

ขั้นตอนแรกคือการติดตั้งไลบรารี Hugging Face ซึ่งรวมถึง TRL และชุดข้อมูลเพื่อปรับแต่งโมเดลแบบเปิด ซึ่งรวมถึงเทคนิค RLHF และการจัดแนวต่างๆ

# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.21.0" \
  "peft==0.14.0" \
  protobuf \
  sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn

หมายเหตุ: หากใช้ GPU ที่มีสถาปัตยกรรม Ampere (เช่น NVIDIA L4) หรือใหม่กว่า คุณจะใช้ Flash Attention ได้ Flash Attention เป็นวิธีที่ช่วยเร่งการคำนวณได้อย่างมากและลดการใช้หน่วยความจำจากกำลังสองเป็นเชิงเส้นในความยาวของลำดับ ซึ่งช่วยเร่งการฝึกได้สูงสุด 3 เท่า ดูข้อมูลเพิ่มเติมได้ที่ FlashAttention

ก่อนเริ่มฝึก คุณต้องตรวจสอบว่าได้ยอมรับข้อกำหนดในการใช้งาน Gemma แล้ว คุณยอมรับใบอนุญาตได้ที่ Hugging Face โดยคลิกปุ่ม "ยอมรับและเข้าถึงที่เก็บ" ในหน้าโมเดลที่ http://huggingface.co/google/gemma-3-1b-pt

หลังจากยอมรับใบอนุญาตแล้ว คุณจะต้องมีโทเค็น Hugging Face ที่ถูกต้องเพื่อเข้าถึงโมเดล หากคุณกำลังเรียกใช้ภายใน Google Colab คุณสามารถใช้โทเค็น Hugging Face ได้อย่างปลอดภัยโดยใช้ความลับของ Colab หรือจะตั้งค่าโทเค็นโดยตรงในเมธอด login ก็ได้ ตรวจสอบว่าโทเค็นมีสิทธิ์เข้าถึงแบบเขียนด้วย เนื่องจากคุณจะพุชโมเดลไปยัง Hub ในระหว่างการฝึก

from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

สร้างและเตรียมชุดข้อมูลการปรับแต่ง

เมื่อปรับแต่ง LLM คุณควรทราบกรณีการใช้งานและงานที่ต้องการแก้ไข ซึ่งจะช่วยให้คุณสร้างชุดข้อมูลเพื่อปรับแต่งโมเดลได้ หากยังไม่ได้กำหนดกรณีการใช้งาน คุณอาจต้องกลับไปที่กระดานออกแบบ

ตัวอย่างเช่น คู่มือนี้มุ่งเน้นไปที่กรณีการใช้งานต่อไปนี้

  • ปรับแต่งโมเดลภาษาธรรมชาติเป็น SQL เพื่อผสานรวมเข้ากับเครื่องมือวิเคราะห์ข้อมูลได้อย่างราบรื่น เป้าหมายคือการลดเวลาและความเชี่ยวชาญที่จำเป็นสำหรับการสร้างคำค้นหา SQL อย่างมาก เพื่อให้แม้แต่ผู้ใช้ที่ไม่เชี่ยวชาญด้านเทคนิคก็สามารถดึงข้อมูลเชิงลึกที่มีความหมายจากข้อมูลได้

การแปลงข้อความเป็น SQL อาจเป็นกรณีการใช้งานที่ดีสำหรับการปรับแต่ง LLM เนื่องจากเป็นงานที่ซับซ้อนซึ่งต้องใช้ความรู้ (ภายใน) จำนวนมากเกี่ยวกับข้อมูลและภาษา SQL

เมื่อพิจารณาแล้วว่าการปรับแต่งเป็นโซลูชันที่เหมาะสม คุณจะต้องมีชุดข้อมูลเพื่อใช้ในการปรับแต่ง ชุดข้อมูลควรเป็นการสาธิตงานที่คุณต้องการแก้ไขที่หลากหลาย การสร้างชุดข้อมูลดังกล่าวทำได้หลายวิธี เช่น

  • การใช้ชุดข้อมูลโอเพนซอร์สที่มีอยู่ เช่น Spider
  • การใช้ชุดข้อมูลสังเคราะห์ที่สร้างโดย LLM เช่น Alpaca
  • การใช้ชุดข้อมูลที่สร้างขึ้นโดยมนุษย์ เช่น Dolly
  • การใช้วิธีการต่างๆ ร่วมกัน เช่น Orca

แต่ละวิธีมีข้อดีและข้อเสียแตกต่างกัน และขึ้นอยู่กับงบประมาณ เวลา และข้อกำหนดด้านคุณภาพ ตัวอย่างเช่น การใช้ชุดข้อมูลที่มีอยู่เป็นวิธีที่ง่ายที่สุด แต่อาจไม่เหมาะกับกรณีการใช้งานที่เฉพาะเจาะจงของคุณ ในขณะที่การใช้ผู้เชี่ยวชาญเฉพาะด้านอาจแม่นยำที่สุด แต่ก็อาจใช้เวลานานและมีค่าใช้จ่ายสูง นอกจากนี้ยังสามารถรวมหลายวิธีเพื่อสร้างชุดข้อมูลคำสั่งได้ ดังที่แสดงใน Orca: Progressive Learning from Complex Explanation Traces of GPT-4

คู่มือนี้ใช้ชุดข้อมูลที่มีอยู่แล้ว (philschmid/gretel-synthetic-text-to-sql) ซึ่งเป็นชุดข้อมูล Text-to-SQL สังเคราะห์คุณภาพสูงที่มีคำสั่งภาษามนุษย์ คำจำกัดความของสคีมา การให้เหตุผล และการค้นหา SQL ที่เกี่ยวข้อง

Hugging Face TRL รองรับการสร้างเทมเพลตชุดข้อมูลการสนทนารูปแบบต่างๆ โดยอัตโนมัติ ซึ่งหมายความว่าคุณเพียงแค่ต้องแปลงชุดข้อมูลเป็นออบเจ็กต์ JSON ที่ถูกต้อง และ trl จะจัดการการสร้างเทมเพลตและจัดรูปแบบให้ถูกต้อง

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

philschmid/gretel-synthetic-text-to-sql มีตัวอย่างมากกว่า 100,000 รายการ เราจะลดขนาดของไกด์ลงเพื่อให้ใช้ตัวอย่างเพียง 10,000 รายการ

ตอนนี้คุณสามารถใช้ไลบรารีชุดข้อมูลของ Hugging Face เพื่อโหลดชุดข้อมูลและสร้างเทมเพลตพรอมต์เพื่อรวมคำสั่งภาษาธรรมชาติ คำจำกัดความของสคีมา และเพิ่มข้อความระบบสำหรับผู้ช่วยได้แล้ว

from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])

ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer

ตอนนี้คุณพร้อมที่จะปรับแต่งโมเดลแล้ว SFTTrainer ของ TRL จาก Hugging Face ช่วยให้การดูแลการปรับแต่ง LLM แบบเปิดเป็นเรื่องง่าย SFTTrainer เป็นคลาสย่อยของ Trainer จากไลบรารี transformers และรองรับฟีเจอร์ทั้งหมดเหมือนกัน ซึ่งรวมถึงการบันทึก การประเมิน และการตรวจสอบ แต่เพิ่มฟีเจอร์เพิ่มเติมเพื่ออำนวยความสะดวกในการใช้งาน ซึ่งรวมถึงสิ่งต่อไปนี้

  • การจัดรูปแบบชุดข้อมูล รวมถึงรูปแบบการสนทนาและรูปแบบคำสั่ง
  • การฝึกเฉพาะการเติมข้อความให้สมบูรณ์โดยไม่สนใจพรอมต์
  • การแพ็กชุดข้อมูลเพื่อการฝึกที่มีประสิทธิภาพมากขึ้น
  • รองรับการปรับแต่งอย่างละเอียดที่มีประสิทธิภาพด้านพารามิเตอร์ (PEFT) รวมถึง QLoRA
  • การเตรียมโมเดลและโทเค็นไนเซอร์สำหรับการปรับแต่งแบบสนทนา (เช่น การเพิ่มโทเค็นพิเศษ)

โค้ดต่อไปนี้จะโหลดโมเดลและโทเค็นไนเซอร์ Gemma จาก Hugging Face และเริ่มต้นการกำหนดค่าการหาปริมาณ

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

SFTTrainer รองรับการผสานรวมกับ peft โดยตรง ซึ่งช่วยให้ปรับแต่ง LLM ได้อย่างมีประสิทธิภาพโดยใช้ QLoRA คุณเพียงแค่ต้องสร้าง LoraConfig และมอบให้ผู้ฝึกสอน

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

ก่อนเริ่มการฝึก คุณต้องกำหนดไฮเปอร์พารามิเตอร์ที่ต้องการใช้ในอินสแตนซ์ SFTConfig

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_length=512,                         # max sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

ตอนนี้คุณมีองค์ประกอบทุกอย่างที่จำเป็นในการสร้าง SFTTrainer เพื่อเริ่มการฝึกโมเดลแล้ว

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)

เริ่มการฝึกโดยเรียกใช้เมธอด train()

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

ก่อนที่จะทดสอบโมเดล โปรดตรวจสอบว่าได้เพิ่มหน่วยความจำแล้ว

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

เมื่อใช้ QLoRA คุณจะฝึกเฉพาะอแดปเตอร์ ไม่ใช่โมเดลทั้งหมด ซึ่งหมายความว่าเมื่อบันทึกโมเดลระหว่างการฝึก คุณจะบันทึกเฉพาะน้ำหนักของอแดปเตอร์ ไม่ใช่โมเดลทั้งหมด หากต้องการบันทึกโมเดลแบบเต็ม ซึ่งจะช่วยให้ใช้งานกับสแต็กการแสดงผล เช่น vLLM หรือ TGI ได้ง่ายขึ้น คุณสามารถผสานน้ำหนักของอแดปเตอร์เข้ากับน้ำหนักของโมเดลได้โดยใช้เมธอด merge_and_unload จากนั้นบันทึกโมเดลด้วยเมธอด save_pretrained ซึ่งจะบันทึกโมเดลเริ่มต้นที่ใช้สำหรับการอนุมานได้

from peft import PeftModel

# Load Model base model
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

ทดสอบการอนุมานโมเดลและสร้างการค้นหา SQL

หลังจากฝึกแล้ว คุณจะต้องประเมินและทดสอบโมเดล คุณโหลดตัวอย่างต่างๆ จากชุดข้อมูลทดสอบและประเมินโมเดลในตัวอย่างเหล่านั้นได้

import torch
from transformers import pipeline

model_id = "gemma-text-to-sql"

# Load Model with PEFT adapter
model = model_class.from_pretrained(
  model_id,
  device_map="auto",
  torch_dtype=torch_dtype,
  attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

มาโหลดตัวอย่างแบบสุ่มจากชุดข้อมูลทดสอบและสร้างคำสั่ง SQL กัน

from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"])-1)
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)

# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

สรุปและขั้นตอนถัดไป

บทแนะนำนี้ครอบคลุมวิธีปรับแต่งโมเดล Gemma โดยใช้ TRL และ QLoRA ดูเอกสารต่อไปนี้