Affiner Gemma à l'aide de Hugging Face Transformers et de QloRA

Ce guide explique comment affiner Gemma sur un ensemble de données texte-SQL personnalisé à l'aide de Transformers et de TRL de Hugging Face. Vous allez découvrir comment :

  • Qu'est-ce que l'adaptation à rang faible quantifiée (QLoRA) ?
  • Configurer l'environnement de développement
  • Créer et préparer l'ensemble de données de réglage
  • Affiner Gemma à l'aide de TRL et de SFTTrainer
  • Tester l'inférence de modèle et générer des requêtes SQL

Qu'est-ce que l'adaptation à rang faible quantifiée (QLoRA) ?

Ce guide présente l'utilisation de l'adaptation faible à rang quantifié (QLoRA), qui s'est imposée comme une méthode populaire pour affiner efficacement les LLM, car elle réduit les exigences en ressources de calcul tout en maintenant des performances élevées. Dans QloRA, le modèle pré-entraîné est quantifié à 4 bits et les poids sont figés. Des couches d'adaptateur enregistrables (LoRA) sont ensuite associées, et seules les couches d'adaptateur sont entraînées. Ensuite, les poids de l'adaptateur peuvent être fusionnés avec le modèle de base ou conservés en tant qu'adaptateur distinct.

Configurer l'environnement de développement

La première étape consiste à installer les bibliothèques Hugging Face, y compris TRL, et les ensembles de données pour affiner le modèle ouvert, y compris différentes techniques d'alignement et de RLHF.

# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  protobuf \
  sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn

Remarque: Si vous utilisez un GPU avec architecture Ampere (comme NVIDIA L4) ou une version ultérieure, vous pouvez utiliser l'attention Flash. Flash Attention est une méthode qui accélère considérablement les calculs et réduit l'utilisation de la mémoire de façon linéaire en fonction de la longueur de la séquence, ce qui accélère l'entraînement jusqu'à trois fois. Pour en savoir plus, consultez FlashAttention.

Avant de pouvoir commencer l'entraînement, vous devez vous assurer d'avoir accepté les conditions d'utilisation de Gemma. Vous pouvez accepter la licence sur Hugging Face en cliquant sur le bouton "Accepter et accéder au dépôt" sur la page du modèle à l'adresse suivante: http://huggingface.co/google/gemma-3-1b-pt

Une fois que vous avez accepté la licence, vous avez besoin d'un jeton Hugging Face valide pour accéder au modèle. Si vous exécutez votre code dans un Google Colab, vous pouvez utiliser votre jeton Hugging Face de manière sécurisée à l'aide des secrets Colab. Sinon, vous pouvez définir le jeton directement dans la méthode login. Assurez-vous également que votre jeton dispose d'un accès en écriture, car vous envoyez votre modèle au hub pendant l'entraînement.

from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

Créer et préparer l'ensemble de données de réglage

Lorsque vous affinez des LLM, il est important de connaître votre cas d'utilisation et la tâche que vous souhaitez résoudre. Cela vous aide à créer un ensemble de données pour affiner votre modèle. Si vous n'avez pas encore défini votre cas d'utilisation, vous pouvez revenir à la planche à dessin.

Par exemple, ce guide se concentre sur le cas d'utilisation suivant:

  • Ajustez un modèle de conversion du langage naturel en langage SQL pour une intégration transparente dans un outil d'analyse de données. L'objectif est de réduire considérablement le temps et l'expertise requis pour générer des requêtes SQL, afin que même les utilisateurs non techniques puissent extraire des insights pertinents à partir des données.

La conversion de texte en SQL peut être un bon cas d'utilisation pour affiner les LLM, car il s'agit d'une tâche complexe qui nécessite de nombreuses connaissances (internes) sur les données et le langage SQL.

Une fois que vous avez déterminé que le paramétrage est la solution appropriée, vous avez besoin d'un ensemble de données pour le paramétrer. L'ensemble de données doit être un ensemble varié de démonstrations de la ou des tâches que vous souhaitez résoudre. Il existe plusieurs façons de créer un tel ensemble de données, par exemple:

  • Utiliser des ensembles de données Open Source existants, tels que Spider
  • Utiliser des ensembles de données synthétiques créés par des LLM, tels que Alpaca
  • Utiliser des ensembles de données créés par des humains, comme Dolly.
  • En utilisant une combinaison de méthodes, comme Orca

Chacune de ces méthodes présente ses propres avantages et inconvénients, et dépend du budget, du délai et des exigences de qualité. Par exemple, utiliser un ensemble de données existant est la méthode la plus simple, mais il est possible qu'il ne soit pas adapté à votre cas d'utilisation spécifique. En revanche, faire appel à des experts du domaine peut être la méthode la plus précise, mais elle peut être longue et coûteuse. Il est également possible de combiner plusieurs méthodes pour créer un ensemble de données d'instructions, comme indiqué dans Orca: apprentissage progressif à partir des traces d'explications complexes de GPT-4.

Ce guide utilise un ensemble de données déjà existant (philschmid/gretel-synthetic-text-to-sql), un ensemble de données texte-SQL synthétique de haute qualité comprenant des instructions en langage naturel, des définitions de schéma, un raisonnement et la requête SQL correspondante.

Hugging Face TRL permet de créer automatiquement des modèles pour les formats d'ensembles de données de conversation. Cela signifie que vous n'avez qu'à convertir votre ensemble de données en objets JSON appropriés, et trl se charge de créer des modèles et de le mettre au bon format.

{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

philschmid/gretel-synthetic-text-to-sql contient plus de 100 000 échantillons. Pour réduire la taille du guide, il est réduit à 10 000 échantillons.

Vous pouvez désormais utiliser la bibliothèque Hugging Face Datasets pour charger l'ensemble de données et créer un modèle d'invite afin de combiner l'instruction en langage naturel, la définition du schéma et d'ajouter un message système pour votre assistant.

from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])

Affiner Gemma à l'aide de TRL et de SFTTrainer

Vous êtes maintenant prêt à affiner votre modèle. Le SFTTrainer de Hugging Face TRL permet de superviser facilement l'ajustement fin des LLM ouverts. SFTTrainer est une sous-classe de Trainer de la bibliothèque transformers. Elle est compatible avec toutes les mêmes fonctionnalités, y compris la journalisation, l'évaluation et le point de contrôle, mais ajoute des fonctionnalités supplémentaires pour améliorer la qualité de vie, y compris les suivantes:

  • Mise en forme des ensembles de données, y compris les formats conversationnel et d'instruction
  • Entraînement sur les finalisations uniquement, en ignorant les requêtes
  • Empaqueter des ensembles de données pour un entraînement plus efficace
  • Prise en charge de l'optimisation du réglage des paramètres (PEFT, Parameter-efficient Fine-Tuning), y compris QloRA
  • Préparation du modèle et du tokenizer pour le réglage fin des conversations (par exemple, ajout de jetons spéciaux)

Le code suivant charge le modèle et le tokenizer Gemma à partir de Hugging Face, puis initialise la configuration de quantification.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

SFTTrainer est compatible avec une intégration native avec peft, ce qui permet de régler facilement et efficacement les LLM à l'aide de QLoRA. Il vous suffit de créer un LoraConfig et de le fournir au formateur.

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

Avant de pouvoir commencer votre entraînement, vous devez définir l'hyperparamètre que vous souhaitez utiliser dans une instance SFTConfig.

from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_seq_length=512,                     # max sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

Vous disposez désormais de tous les éléments nécessaires pour créer votre SFTTrainer et commencer l'entraînement de votre modèle.

from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)

Démarrez l'entraînement en appelant la méthode train().

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

Avant de pouvoir tester votre modèle, assurez-vous de libérer la mémoire.

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

Lorsque vous utilisez QLoRA, vous n'entraînez que les adaptateurs, et non le modèle complet. Cela signifie que lorsque vous enregistrez le modèle pendant l'entraînement, vous n'enregistrez que les poids de l'adaptateur et non le modèle complet. Si vous souhaitez enregistrer le modèle complet, ce qui le rend plus facile à utiliser avec des piles de diffusion telles que vLLM ou TGI, vous pouvez fusionner les poids de l'adaptateur avec les poids du modèle à l'aide de la méthode merge_and_unload, puis enregistrer le modèle avec la méthode save_pretrained. Cela permet d'enregistrer un modèle par défaut, qui peut être utilisé pour l'inférence.

from peft import PeftModel

# Load Model base model
model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoTokenizer.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

Tester l'inférence de modèle et générer des requêtes SQL

Une fois l'entraînement terminé, vous devez évaluer et tester votre modèle. Vous pouvez charger différents échantillons à partir de l'ensemble de données de test et évaluer le modèle sur ces échantillons.

import torch
from transformers import pipeline

model_id = "gemma-text-to-sql"

# Load Model with PEFT adapter
model = model_class.from_pretrained(
  model_id,
  device_map="auto",
  torch_dtype=torch_dtype,
  attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Chargeons un échantillon aléatoire de l'ensemble de données de test et générons une commande SQL.

from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:2], tokenize=False, add_generation_prompt=True)

# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Résumé et prochaines étapes

Ce tutoriel vous a expliqué comment affiner un modèle Gemma à l'aide de TRL et de QLoRA. Consultez ensuite les documents suivants: