Affiner EmbeddingGemma

 Afficher sur ai.google.dev Exécuter dans Google Colab  Exécuter dans Kaggle Ouvrir dans Vertex AI Afficher la source sur GitHub

L'affinage permet de combler le fossé entre la compréhension polyvalente d'un modèle et la précision spécialisée et performante requise par votre application. Comme aucun modèle n'est parfait pour toutes les tâches, l'affinage l'adapte à votre domaine spécifique.

Imaginez que votre entreprise, "Shibuya Financial", propose divers produits financiers complexes tels que des fonds d'investissement, des comptes NISA (comptes d'épargne bénéficiant d'avantages fiscaux) et des prêts immobiliers. Votre équipe du service client utilise une base de connaissances interne pour trouver rapidement des réponses aux questions des clients.

Configuration

Avant de commencer ce tutoriel, effectuez les étapes suivantes :

  • Pour accéder à EmbeddingGemma, connectez-vous à Hugging Face et sélectionnez Accepter la licence pour un modèle Gemma.
  • Générez un jeton d'accès Hugging Face et utilisez-le pour vous connecter depuis Colab.

Ce notebook s'exécutera sur un processeur ou un GPU.

Installer des packages Python

Installez les bibliothèques requises pour exécuter le modèle EmbeddingGemma et générer des embeddings. Sentence Transformers est un framework Python pour les embeddings de texte et d'image. Pour en savoir plus, consultez la documentation Sentence Transformers.

pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

Une fois la licence acceptée, vous avez besoin d'un jeton Hugging Face valide pour accéder au modèle.

# Login into Hugging Face Hub
from huggingface_hub import login
login()

Charger le modèle

Utilisez les bibliothèques sentence-transformers pour créer une instance d'une classe de modèle avec EmbeddingGemma.

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696

Préparer l'ensemble de données de réglage fin

C'est la partie la plus importante. Vous devez créer un ensemble de données qui apprend au modèle ce que signifie "similaire" dans votre contexte spécifique. Ces données sont souvent structurées sous forme de triplets (ancrage, positif, négatif).

  • Ancre : requête ou phrase d'origine.
  • Positif : phrase sémantiquement très similaire ou identique à l'ancrage.
  • Négatif : phrase portant sur un thème connexe, mais sémantiquement distinct.

Dans cet exemple, nous n'avons préparé que trois triplets, mais pour une application réelle, vous auriez besoin d'un ensemble de données beaucoup plus volumineux pour obtenir de bons résultats.

from datasets import Dataset

dataset = [
    ["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
    ["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
    ["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
]

# Convert the list-based dataset into a list of dictionaries.
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]

# Create a Hugging Face `Dataset` object from the list of dictionaries.
train_dataset = Dataset.from_list(data_as_dicts)
print(train_dataset)
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 3
})

Avant l'affinage

Une recherche sur "investissement sans impôt" peut avoir donné les résultats suivants, avec des scores de similarité :

  1. Document : Ouvrir un compte NISA (score : 0,45)
  2. Document : Ouvrir un compte d'épargne régulier (score : 0,48) <- Score similaire, potentiellement déroutant
  3. Document : Guide de demande de prêt immobilier (score : 0,42)
task_name = "STS"

def get_scores(query, documents):
  # Calculate embeddings by calling model.encode()
  query_embeddings = model.encode(query, prompt=task_name)
  doc_embeddings = model.encode(documents, prompt=task_name)

  # Calculate the embedding similarities
  similarities = model.similarity(query_embeddings, doc_embeddings)

  for idx, doc in enumerate(documents):
    print("Document: ", doc, "-> 🤖 Score: ", similarities.numpy()[0][idx])

query = "I want to start a tax-free installment investment, what should I do?"
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]

get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.45698774
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.48092696
Document:  Home Loan Application Guide -> 🤖 Score:  0.42127067

Formation

En utilisant un framework comme sentence-transformers en Python, le modèle de base apprend progressivement les subtiles distinctions de votre vocabulaire financier.

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback

loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="my-embedding-gemma",
    # Optional training parameters:
    prompts=model.prompts[task_name],    # use model's prompt to train
    num_train_epochs=5,
    per_device_train_batch_size=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Optional tracking/debugging parameters:
    logging_steps=train_dataset.num_rows,
    report_to="none",
)

class MyCallback(TrainerCallback):
    "A callback that evaluates the model at the end of eopch"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Evaluate the model using text generation
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

def evaluate():
  get_scores(query, documents)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[MyCallback(evaluate)]
)
trainer.train()
Step 3 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.6449194
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.44123
Document:  Home Loan Application Guide -> 🤖 Score:  0.46752414
Step 6 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.68873787
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.34069622
Document:  Home Loan Application Guide -> 🤖 Score:  0.50065553
Step 9 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7148906
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.30480516
Document:  Home Loan Application Guide -> 🤖 Score:  0.52454984
Step 12 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.72614634
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.29255486
Document:  Home Loan Application Guide -> 🤖 Score:  0.5370023
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
TrainOutput(global_step=15, training_loss=0.009651281436261646, metrics={'train_runtime': 63.2486, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.237, 'total_flos': 0.0, 'train_loss': 0.009651281436261646, 'epoch': 5.0})

Après l'affinage

La même recherche donne désormais des résultats beaucoup plus clairs :

  1. Document : Ouvrir un compte NISA (score : 0,72) <- Beaucoup plus sûr
  2. Document : Ouvrir un compte d'épargne régulier (score : 0,28) <- Nettement moins pertinent
  3. Document : Guide de demande de prêt immobilier (score : 0,54)
get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913

Pour importer votre modèle dans le Hub Hugging Face, vous pouvez utiliser la méthode push_to_hub de la bibliothèque Sentence Transformers.

Importer votre modèle vous permet d'y accéder facilement pour l'inférence directement depuis le Hub, de le partager avec d'autres utilisateurs et de gérer les versions de votre travail. Une fois votre modèle importé, n'importe qui peut le charger avec une seule ligne de code, en référençant simplement son ID de modèle unique <username>/my-embedding-gemma.

# Push to Hub
model.push_to_hub("my-embedding-gemma")

Résumé et étapes suivantes

Vous avez maintenant appris à adapter un modèle EmbeddingGemma à un domaine spécifique en l'affinant avec la bibliothèque Sentence Transformers.

Découvrez ce que vous pouvez faire d'autre avec EmbeddingGemma :