คู่มือนี้จะแนะนำวิธีปรับแต่ง Gemma ในชุดข้อมูลข้อความเป็น SQL ที่กําหนดเองโดยใช้ Transformers และ TRL ของ Hugging Face คุณจะได้เรียนรู้:
- Quantized Low-Rank Adaptation (QLoRA) คืออะไร
- ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์
- สร้างและเตรียมชุดข้อมูลการปรับแต่ง
- ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer
- ทดสอบการอนุมานโมเดลและสร้างการค้นหา SQL
Quantized Low-Rank Adaptation (QLoRA) คืออะไร
คู่มือนี้จะแสดงการใช้ Quantized Low-Rank Adaptation (QLoRA) ซึ่งเป็นวิธีการยอดนิยมในการปรับแต่ง LLM อย่างมีประสิทธิภาพ เนื่องจากช่วยลดความต้องการทรัพยากรการประมวลผลไปพร้อมกับรักษาประสิทธิภาพให้สูง ใน QloRA โมเดลที่ผ่านการฝึกอบรมล่วงหน้าจะได้รับการแปลงเป็น 4 บิตและน้ำหนักจะได้รับการตรึงไว้ จากนั้นจะแนบเลเยอร์อะแดปเตอร์ที่ฝึกได้ (LoRA) และฝึกเฉพาะเลเยอร์อะแดปเตอร์ หลังจากนั้น คุณสามารถผสานน้ำหนักของอะแดปเตอร์เข้ากับโมเดลพื้นฐานหรือเก็บไว้เป็นอะแดปเตอร์แยกต่างหากก็ได้
ตั้งค่าสภาพแวดล้อมในการพัฒนาซอฟต์แวร์
ขั้นตอนแรกคือการติดตั้งไลบรารี Hugging Face ซึ่งรวมถึง TRL และชุดข้อมูลเพื่อปรับแต่งโมเดลแบบเปิด รวมถึง RLHF และเทคนิคการจัดแนวที่แตกต่างกัน
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard
# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"
# Install Hugging Face libraries
%pip install --upgrade \
"datasets==3.3.2" \
"accelerate==1.4.0" \
"evaluate==0.4.3" \
"bitsandbytes==0.45.3" \
"trl==0.15.2" \
"peft==0.14.0" \
protobuf \
sentencepiece
# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn
หมายเหตุ: หากใช้ GPU ที่มีสถาปัตยกรรม Ampere (เช่น NVIDIA L4) ขึ้นไป คุณจะใช้ Flash attention ได้ Flash Attention เป็นวิธีการที่เพิ่มความเร็วในการคำนวณและลดการใช้หน่วยความจําจากแบบ 2 เท่าเป็นแบบเชิงเส้นตามความยาวของลําดับ ซึ่งทําให้การฝึกเร็วขึ้นถึง 3 เท่า ดูข้อมูลเพิ่มเติมได้ที่ FlashAttention
คุณต้องยอมรับข้อกำหนดในการใช้งาน Gemma ก่อนจึงจะเริ่มการฝึกได้ คุณสามารถยอมรับใบอนุญาตใน Hugging Face ได้โดยคลิกปุ่ม "ยอมรับและเข้าถึงที่เก็บ" ในหน้าโมเดลที่ http://huggingface.co/google/gemma-3-1b-pt
หลังจากยอมรับใบอนุญาตแล้ว คุณต้องมีโทเค็น Hugging Face ที่ถูกต้องจึงจะเข้าถึงโมเดลได้ หากเรียกใช้ภายใน Google Colab คุณจะใช้โทเค็น Hugging Face ได้อย่างปลอดภัยโดยใช้ข้อมูลลับของ Colab หรือจะตั้งค่าโทเค็นในเมธอด login
โดยตรงก็ได้ ตรวจสอบว่าโทเค็นของคุณมีสิทธิ์เขียนด้วย เนื่องจากคุณจะต้องพุชโมเดลไปยัง Hub ระหว่างการฝึก
from google.colab import userdata
from huggingface_hub import login
# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)
สร้างและเตรียมชุดข้อมูลสำหรับการปรับแต่ง
เมื่อปรับแต่ง LLM คุณต้องทราบกรณีการใช้งานและงานที่คุณต้องการแก้ปัญหา ซึ่งจะช่วยคุณสร้างชุดข้อมูลเพื่อปรับแต่งโมเดล หากยังไม่ได้กําหนดกรณีการใช้งาน คุณอาจต้องกลับไปที่กระดานวาดภาพ
ตัวอย่างเช่น คู่มือนี้มุ่งเน้นที่กรณีการใช้งานต่อไปนี้
- ปรับแต่งภาษาธรรมชาติเป็นโมเดล SQL เพื่อผสานรวมกับเครื่องมือวิเคราะห์ข้อมูลอย่างราบรื่น โดยมีวัตถุประสงค์เพื่อลดเวลาและความเชี่ยวชาญที่จําเป็นสําหรับการสร้างคําค้นหา SQL อย่างมาก ซึ่งช่วยให้ผู้ใช้ที่ไม่เชี่ยวชาญด้านเทคนิคสามารถดึงข้อมูลเชิงลึกที่มีประโยชน์จากข้อมูลได้
การเปลี่ยนข้อความเป็น SQL อาจเป็น Use Case ที่เหมาะสําหรับการปรับแต่ง LLM เนื่องจากเป็นงานที่ซับซ้อนซึ่งต้องใช้ความรู้ (ภายใน) เกี่ยวกับข้อมูลและภาษา SQL เป็นอย่างมาก
เมื่อพิจารณาแล้วว่าการปรับแต่งเป็นวิธีแก้ปัญหาที่เหมาะสม คุณจะต้องมีชุดข้อมูลเพื่อปรับแต่ง ชุดข้อมูลควรเป็นชุดการสาธิตงานที่หลากหลายซึ่งคุณต้องการแก้ปัญหา การสร้างชุดข้อมูลดังกล่าวมีหลายวิธี ดังนี้
- การใช้ชุดข้อมูลโอเพนซอร์สที่มีอยู่ เช่น Spider
- การใช้ชุดข้อมูลสังเคราะห์ที่สร้างโดย LLM เช่น Alpaca
- การใช้ชุดข้อมูลที่มนุษย์สร้างขึ้น เช่น Dolly
- ใช้วิธีการต่างๆ ร่วมกัน เช่น Orca
แต่ละวิธีมีทั้งข้อดีและข้อเสีย และขึ้นอยู่กับงบประมาณ เวลา และข้อกำหนดด้านคุณภาพ เช่น การใช้ชุดข้อมูลที่มีอยู่นั้นง่ายที่สุด แต่อาจไม่เหมาะกับกรณีการใช้งานที่เฉพาะเจาะจงของคุณ ส่วนการใช้ผู้เชี่ยวชาญด้านโดเมนอาจมีความแม่นยำมากที่สุด แต่ก็อาจใช้เวลานานและค่าใช้จ่ายสูง นอกจากนี้ คุณยังรวมวิธีการต่างๆ เพื่อสร้างชุดข้อมูลคำสั่งได้ ดังที่แสดงใน Orca: Progressive Learning from Complex Explanation Traces of GPT-4
คู่มือนี้ใช้ชุดข้อมูลที่มีอยู่ (philschmid/gretel-synthetic-text-to-sql) ซึ่งเป็นชุดข้อมูล Text-to-SQL สังเคราะห์คุณภาพสูงที่มีคําสั่งภาษาธรรมชาติ คําจํากัดความสคีมา การให้เหตุผล และการค้นหา SQL ที่เกี่ยวข้อง
Hugging Face TRL รองรับการสร้างเทมเพลตรูปแบบชุดข้อมูลการสนทนาโดยอัตโนมัติ ซึ่งหมายความว่าคุณจะต้องแปลงชุดข้อมูลให้เป็นออบเจ็กต์ JSON ที่ถูกต้องเท่านั้น และ trl
จะจัดการเรื่องเทมเพลตและจัดรูปแบบให้ถูกต้อง
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
philschmid/gretel-synthetic-text-to-sql มีตัวอย่างมากกว่า 100,000 รายการ เราได้ลดขนาดตัวอย่างมาใช้เพียง 10,000 รายการเพื่อให้คำแนะนำมีขนาดเล็ก
ตอนนี้คุณใช้คลังชุดข้อมูล Hugging Face เพื่อโหลดชุดข้อมูลและสร้างเทมเพลตพรอมต์เพื่อรวมคำสั่งภาษาที่เป็นธรรมชาติ คําจํากัดความของสคีมา และเพิ่มข้อความของระบบสําหรับผู้ช่วยได้แล้ว
from datasets import load_dataset
# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""
# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.
<SCHEMA>
{context}
</SCHEMA>
<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
return {
"messages": [
# {"role": "system", "content": system_message},
{"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
{"role": "assistant", "content": sample["sql"]}
]
}
# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))
# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)
# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])
ปรับแต่ง Gemma โดยใช้ TRL และ SFTTrainer
ตอนนี้คุณก็พร้อมที่จะปรับแต่งโมเดลแล้ว Hugging Face TRL SFTTrainer ช่วยให้คุณควบคุมการปรับแต่ง LLM แบบเปิดได้อย่างง่ายดาย SFTTrainer
เป็นคลาสย่อยของ Trainer
จากไลบรารี transformers
และรองรับฟีเจอร์เดียวกันทั้งหมด ซึ่งรวมถึงการบันทึก การประเมิน และการตรวจสอบจุดตรวจสอบ แต่เพิ่มฟีเจอร์อื่นๆ ที่ช่วยให้ใช้งานได้ง่ายขึ้น ซึ่งได้แก่
- การจัดรูปแบบชุดข้อมูล รวมถึงรูปแบบการสนทนาและคำสั่ง
- การฝึกอบรมเกี่ยวกับการดำเนินการเสร็จสมบูรณ์เท่านั้น โดยจะไม่สนใจพรอมต์
- การแพ็กชุดข้อมูลเพื่อให้การฝึกมีประสิทธิภาพมากขึ้น
- การรองรับการปรับแต่งแบบละเอียดที่มีประสิทธิภาพของพารามิเตอร์ (PEFT) รวมถึง QloRA
- เตรียมโมเดลและตัวแยกวิเคราะห์เพื่อปรับแต่งการสนทนาแบบละเอียด (เช่น การเพิ่มโทเค็นพิเศษ)
โค้ดต่อไปนี้จะโหลดโมเดล Gemma และตัวแยกวิเคราะห์จาก Hugging Face และเริ่มต้นการกําหนดค่าการแปลงค่า
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig
# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`
# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
model_class = AutoModelForCausalLM
else:
model_class = AutoModelForImageTextToText
# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float16
# Define model init arguments
model_kwargs = dict(
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
device_map="auto", # Let torch decide how to load the model
)
# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)
# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template
SFTTrainer
รองรับการผสานรวมแบบเนทีฟกับ peft
ซึ่งช่วยให้ปรับแต่ง LLM ได้อย่างมีประสิทธิภาพโดยใช้ QLoRA คุณเพียงแค่สร้าง LoraConfig
แล้วส่งให้เทรนเนอร์
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)
คุณต้องกําหนดไฮเปอร์พารามิเตอร์ที่ต้องการใช้ในอินสแตนซ์ SFTConfig
ก่อนจึงจะเริ่มการฝึกได้
from trl import SFTConfig
args = SFTConfig(
output_dir="gemma-text-to-sql", # directory to save and repository id
max_seq_length=512, # max sequence length for model and packing of the dataset
packing=True, # Groups multiple samples in the dataset into a single sequence
num_train_epochs=3, # number of training epochs
per_device_train_batch_size=1, # batch size per device during training
gradient_accumulation_steps=4, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory
optim="adamw_torch_fused", # use fused adamw optimizer
logging_steps=10, # log every 10 steps
save_strategy="epoch", # save checkpoint every epoch
learning_rate=2e-4, # learning rate, based on QLoRA paper
fp16=True if torch_dtype == torch.float16 else False, # use float16 precision
bf16=True if torch_dtype == torch.bfloat16 else False, # use bfloat16 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to hub
report_to="tensorboard", # report metrics to tensorboard
dataset_kwargs={
"add_special_tokens": False, # We template with special tokens
"append_concat_token": True, # Add EOS token as separator token between examples
}
)
ตอนนี้คุณมีองค์ประกอบพื้นฐานทั้งหมดในการสร้าง SFTTrainer
เพื่อเริ่มการฝึกโมเดลแล้ว
from trl import SFTTrainer
# Create Trainer object
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset["train"],
peft_config=peft_config,
processing_class=tokenizer
)
เริ่มการฝึกโดยเรียกใช้เมธอด train()
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()
# Save the final model again to the Hugging Face Hub
trainer.save_model()
อย่าลืมเพิ่มพื้นที่ว่างในหน่วยความจำก่อนจึงจะทดสอบโมเดลได้
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
เมื่อใช้ QLoRA คุณจะฝึกเฉพาะอะแดปเตอร์เท่านั้น ไม่ใช่ทั้งโมเดล ซึ่งหมายความว่าเมื่อบันทึกโมเดลระหว่างการฝึก คุณบันทึกเฉพาะน้ำหนักของอะแดปเตอร์เท่านั้น ไม่ใช่โมเดลทั้งหมด หากต้องการบันทึกโมเดลแบบเต็ม ซึ่งทําให้ใช้งานกับสแต็กการแสดงโฆษณา เช่น vLLM หรือ TGI ได้ง่ายขึ้น คุณสามารถผสานน้ำหนักของอะแดปเตอร์เข้ากับน้ำหนักของโมเดลได้โดยใช้เมธอด merge_and_unload
จากนั้นบันทึกโมเดลด้วยเมธอด save_pretrained
ซึ่งจะบันทึกโมเดลเริ่มต้นไว้เพื่อใช้สำหรับการอนุมาน
from peft import PeftModel
# Load Model base model
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)
# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
ทดสอบการอนุมานโมเดลและสร้างการค้นหา SQL
หลังจากการฝึกเสร็จสิ้นแล้ว คุณจะต้องประเมินและทดสอบโมเดล คุณสามารถโหลดตัวอย่างต่างๆ จากชุดทดสอบและประเมินโมเดลกับตัวอย่างเหล่านั้นได้
import torch
from transformers import pipeline
model_id = "gemma-text-to-sql"
# Load Model with PEFT adapter
model = model_class.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch_dtype,
attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
มาโหลดตัวอย่างแบบสุ่มจากชุดข้อมูลทดสอบและสร้างคําสั่ง SQL กัน
from random import randint
import re
# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]
# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)
# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)
# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
สรุปและขั้นตอนถัดไป
บทแนะนำนี้ครอบคลุมวิธีปรับแต่งโมเดล Gemma โดยใช้ TRL และ QLoRA โปรดอ่านเอกสารต่อไปนี้