Точная настройка Джеммы с помощью Hugging Face Transformers и QloRA

В этом руководстве рассказывается, как точно настроить Gemma для пользовательского набора данных преобразования текста в SQL с помощью Hugging Face Transformers и TRL . Вы узнаете:

  • Что такое квантованная адаптация низкого ранга (QLoRA)
  • Настройка среды разработки
  • Создайте и подготовьте набор данных для точной настройки.
  • Точная настройка Джеммы с помощью TRL и SFTTrainer.
  • Тестирование вывода модели и создание SQL-запросов

Что такое квантованная адаптация низкого ранга (QLoRA)

В этом руководстве демонстрируется использование квантовой низкоранговой адаптации (QLoRA) , которая стала популярным методом эффективной точной настройки LLM, поскольку она снижает требования к вычислительным ресурсам при сохранении высокой производительности. В QloRA предварительно обученная модель квантуется до 4 бит, а веса замораживаются. Затем прикрепляются обучаемые слои-адаптеры (LoRA) и обучаются только слои-адаптеры. После этого грузики адаптера можно объединить с базовой моделью или сохранить как отдельный адаптер.

Настройка среды разработки

Первым шагом является установка библиотек Hugging Face, включая TRL, и наборов данных для точной настройки открытой модели, включая различные методы RLHF и выравнивания.

# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  protobuf \
  sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn

Примечание. Если вы используете графический процессор с архитектурой Ampere (например, NVIDIA L4) или новее, вы можете использовать Flash-внимание. Flash Attention — это метод, который значительно ускоряет вычисления и сокращает использование памяти с квадратичного до линейного по длине последовательности, что приводит к ускорению обучения до 3 раз. Узнайте больше на FlashAttention .

Прежде чем вы сможете начать обучение, вам необходимо убедиться, что вы приняли условия использования Gemma. Вы можете принять лицензию на Hugging Face , нажав кнопку «Принять и получить доступ к репозиторию» на странице модели по адресу: http://huggingface.co/google/gemma-3-1b-pt.

После того как вы приняли лицензию, вам понадобится действительный токен Hugging Face Token для доступа к модели. Если вы работаете внутри Google Colab, вы можете безопасно использовать свой токен Hugging Face, используя секреты Colab, в противном случае вы можете установить токен непосредственно в методе login . Убедитесь, что у вашего токена также есть доступ на запись, когда вы отправляете свою модель в хаб во время обучения.

from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

Создайте и подготовьте набор данных для точной настройки.

При точной настройке LLM важно знать ваш вариант использования и задачу, которую вы хотите решить. Это поможет вам создать набор данных для точной настройки вашей модели. Если вы еще не определили свой вариант использования, возможно, вам захочется вернуться к чертежной доске.

В качестве примера в этом руководстве рассматривается следующий вариант использования:

  • Точная настройка естественного языка на модель SQL для полной интеграции в инструмент анализа данных. Цель состоит в том, чтобы значительно сократить время и знания, необходимые для генерации SQL-запросов, позволяя даже нетехническим пользователям извлекать значимую информацию из данных.

Преобразование текста в SQL может быть хорошим вариантом использования для тонкой настройки LLM, поскольку это сложная задача, требующая большого количества (внутренних) знаний о данных и языке SQL.

После того как вы определили, что точная настройка является правильным решением, вам понадобится набор данных для точной настройки. Набор данных должен представлять собой разнообразный набор демонстраций задач, которые вы хотите решить. Существует несколько способов создания такого набора данных, в том числе:

  • Использование существующих наборов данных с открытым исходным кодом, таких как Spider
  • Использование синтетических наборов данных, созданных LLM, таких как Alpaca.
  • Использование наборов данных, созданных людьми, например Dolly .
  • Использование комбинации методов, таких как Orca

Каждый из методов имеет свои преимущества и недостатки и зависит от бюджета, времени и требований к качеству. Например, использование существующего набора данных является самым простым, но может не быть адаптировано к вашему конкретному случаю использования, тогда как использование экспертов в предметной области может быть наиболее точным, но может занять много времени и дорого. Также возможно объединить несколько методов для создания набора данных инструкций, как показано в Orca: Progressive Learning from Complex Explanation Traces GPT-4.

В этом руководстве используется уже существующий набор данных ( philschmid/gretel-synthetic-text-to-sql ), высококачественный синтетический набор данных преобразования текста в SQL, включающий инструкции на естественном языке, определения схем, рассуждения и соответствующий SQL-запрос.

Hugging Face TRL поддерживает автоматическое создание шаблонов форматов наборов данных разговоров. Это означает, что вам нужно только преобразовать ваш набор данных в правильные объекты JSON, а trl позаботится о создании шаблонов и приведении его в правильный формат.

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Файл philschmid/gretel-synthetic-text-to-sql содержит более 100 тыс. образцов. Чтобы сохранить руководство небольшим, его дискретизация уменьшена до 10 000 сэмплов.

Теперь вы можете использовать библиотеку наборов данных Hugging Face, чтобы загрузить набор данных и создать шаблон подсказки, объединяющий инструкции на естественном языке, определение схемы и добавление системного сообщения для вашего помощника.

from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])

Точная настройка Джеммы с помощью TRL и SFTTrainer.

Теперь вы готовы к точной настройке вашей модели. Hugging Face TRL SFTTrainer позволяет легко контролировать точную настройку открытых LLM. SFTTrainer является подклассом Trainer из библиотеки transformers и поддерживает все те же функции, включая ведение журнала, оценку и контрольные точки, но добавляет дополнительные функции качества жизни, в том числе:

  • Форматирование набора данных, включая диалоговые форматы и форматы инструкций.
  • Обучение только завершению, игнорируя подсказки
  • Упаковка наборов данных для более эффективного обучения
  • Поддержка точной настройки с эффективным использованием параметров (PEFT), включая QloRA
  • Подготовка модели и токенизатора для интерактивной тонкой настройки (например, добавления специальных токенов).

Следующий код загружает модель и токенизатор Gemma из Hugging Face и инициализирует конфигурацию квантования.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

SFTTrainer поддерживает встроенную интеграцию с peft , что упрощает эффективную настройку LLM с помощью QLoRA. Вам нужно только создать LoraConfig и предоставить его тренеру.

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

Прежде чем вы сможете начать обучение, вам необходимо определить гиперпараметр, который вы хотите использовать в экземпляре SFTConfig .

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_seq_length=512,                     # max sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

Теперь у вас есть все необходимые строительные блоки для создания SFTTrainer и начала обучения вашей модели.

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)

Начните обучение, вызвав метод train() .

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

Прежде чем протестировать свою модель, обязательно освободите память.

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

При использовании QLoRA вы обучаете только адаптеры, а не полную модель. Это означает, что при сохранении модели во время тренировки вы сохраняете только веса адаптера, а не всю модель. Если вы хотите сохранить полную модель, что упрощает ее использование со стеками обслуживания, такими как vLLM или TGI, вы можете объединить веса адаптера с весами модели с помощью метода merge_and_unload , а затем сохранить модель с помощью метода save_pretrained . При этом сохраняется модель по умолчанию, которую можно использовать для вывода.

from peft import PeftModel

# Load Model base model
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

Тестирование вывода модели и создание SQL-запросов

После завершения обучения вам нужно будет оценить и протестировать свою модель. Вы можете загрузить различные образцы из тестового набора данных и оценить модель на этих образцах.

import torch
from transformers import pipeline

model_id = "gemma-text-to-sql"

# Load Model with PEFT adapter
model = model_class.from_pretrained(
  model_id,
  device_map="auto",
  torch_dtype=torch_dtype,
  attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Давайте загрузим случайную выборку из тестового набора данных и сгенерируем команду SQL.

from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)

# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Резюме и следующие шаги

В этом руководстве рассказывается, как точно настроить модель Gemma с помощью TRL и QLoRA. Далее ознакомьтесь со следующими документами: