EmbeddingGemma abstimmen

Auf ai.google.dev ansehen In Google Colab ausführen In Kaggle ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

Durch das Feinabstimmen wird die Lücke zwischen dem allgemeinen Verständnis eines Modells und der spezialisierten, leistungsstarken Genauigkeit, die für Ihre Anwendung erforderlich ist, geschlossen. Da kein einzelnes Modell für jede Aufgabe perfekt ist, wird es durch die Feinabstimmung an Ihre spezifische Domain angepasst.

Angenommen, Ihr Unternehmen „Shibuya Financial“ bietet verschiedene komplexe Finanzprodukte wie Investmentfonds, NISA-Konten (ein steuerbegünstigtes Sparkonto) und Baufinanzierungen an. Ihr Kundenserviceteam verwendet eine interne Wissensdatenbank, um schnell Antworten auf Kundenfragen zu finden.

Einrichtung

Führen Sie die folgenden Schritte aus, bevor Sie mit dieser Anleitung beginnen:

  • Sie erhalten Zugriff auf EmbeddingGemma, indem Sie sich bei Hugging Face anmelden und für ein Gemma-Modell Lizenz bestätigen auswählen.
  • Generieren Sie ein Hugging Face-Zugriffstoken und verwenden Sie es, um sich über Colab anzumelden.

Dieses Notebook kann entweder auf der CPU oder der GPU ausgeführt werden.

Python-Pakete installieren

Installieren Sie die Bibliotheken, die zum Ausführen des EmbeddingGemma-Modells und zum Generieren von Einbettungen erforderlich sind. Sentence Transformers ist ein Python-Framework für Text- und Bildeinbettungen. Weitere Informationen finden Sie in der Dokumentation zu Sentence Transformers.

pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

Nachdem Sie die Lizenz akzeptiert haben, benötigen Sie ein gültiges Hugging Face-Token, um auf das Modell zuzugreifen.

# Login into Hugging Face Hub
from huggingface_hub import login
login()

Modell laden

Verwenden Sie die sentence-transformers-Bibliotheken, um eine Instanz einer Modellklasse mit EmbeddingGemma zu erstellen.

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696

Dataset für die Feinabstimmung vorbereiten

Das ist der wichtigste Teil. Sie müssen ein Dataset erstellen, mit dem das Modell lernt, was in Ihrem spezifischen Kontext „ähnlich“ bedeutet. Diese Daten sind oft als Tripletts strukturiert: (Anker, positiv, negativ).

  • Anker: Die ursprüngliche Anfrage oder der ursprüngliche Satz.
  • Positiv: Ein Satz, der dem Anker semantisch sehr ähnlich oder identisch ist.
  • Negativ: Ein Satz, der sich auf ein verwandtes Thema bezieht, aber semantisch unterschiedlich ist.

In diesem Beispiel haben wir nur drei Tupel vorbereitet. Für eine echte Anwendung wäre jedoch ein viel größerer Datensatz erforderlich, um gute Ergebnisse zu erzielen.

from datasets import Dataset

dataset = [
    ["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
    ["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
    ["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
]

# Convert the list-based dataset into a list of dictionaries.
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]

# Create a Hugging Face `Dataset` object from the list of dictionaries.
train_dataset = Dataset.from_list(data_as_dicts)
print(train_dataset)
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 3
})

Vor der Feinabstimmung

Eine Suche nach „steuerfreie Anlage“ hätte möglicherweise die folgenden Ergebnisse mit Ähnlichkeitswerten geliefert:

  1. Dokument: NISA-Konto eröffnen (Punktzahl: 0,45)
  2. Dokument: Eröffnung eines Girokontos (Score: 0,48) <- Ähnlicher Score, potenziell verwirrend
  3. Dokument: Leitfaden für Hausdarlehensanträge (Punktzahl: 0,42)
task_name = "STS"

def get_scores(query, documents):
  # Calculate embeddings by calling model.encode()
  query_embeddings = model.encode(query, prompt=task_name)
  doc_embeddings = model.encode(documents, prompt=task_name)

  # Calculate the embedding similarities
  similarities = model.similarity(query_embeddings, doc_embeddings)

  for idx, doc in enumerate(documents):
    print("Document: ", doc, "-> 🤖 Score: ", similarities.numpy()[0][idx])

query = "I want to start a tax-free installment investment, what should I do?"
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]

get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.45698774
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.48092696
Document:  Home Loan Application Guide -> 🤖 Score:  0.42127067

Training

Wenn Sie ein Framework wie sentence-transformers in Python verwenden, lernt das Basismodell nach und nach die subtilen Unterschiede in Ihrem Finanzvokabular.

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback

loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="my-embedding-gemma",
    # Optional training parameters:
    prompts=model.prompts[task_name],    # use model's prompt to train
    num_train_epochs=5,
    per_device_train_batch_size=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Optional tracking/debugging parameters:
    logging_steps=train_dataset.num_rows,
    report_to="none",
)

class MyCallback(TrainerCallback):
    "A callback that evaluates the model at the end of eopch"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Evaluate the model using text generation
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

def evaluate():
  get_scores(query, documents)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[MyCallback(evaluate)]
)
trainer.train()
Step 3 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.6449194
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.44123
Document:  Home Loan Application Guide -> 🤖 Score:  0.46752414
Step 6 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.68873787
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.34069622
Document:  Home Loan Application Guide -> 🤖 Score:  0.50065553
Step 9 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7148906
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.30480516
Document:  Home Loan Application Guide -> 🤖 Score:  0.52454984
Step 12 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.72614634
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.29255486
Document:  Home Loan Application Guide -> 🤖 Score:  0.5370023
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
TrainOutput(global_step=15, training_loss=0.009651281436261646, metrics={'train_runtime': 63.2486, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.237, 'total_flos': 0.0, 'train_loss': 0.009651281436261646, 'epoch': 5.0})

Nach der Feinabstimmung

Dieselbe Suche liefert jetzt viel klarere Ergebnisse:

  1. Dokument: Eröffnung eines NISA-Kontos (Punktzahl: 0,72) <- Viel sicherer
  2. Dokument: Opening a Regular Saving Account (Score: 0.28) <- Deutlich weniger relevant
  3. Dokument: Leitfaden für die Beantragung eines Hausdarlehens (Wert: 0,54)
get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913

Wenn Sie Ihr Modell in den Hugging Face Hub hochladen möchten, können Sie die Methode push_to_hub aus der Sentence Transformers-Bibliothek verwenden.

Wenn Sie Ihr Modell hochladen, können Sie direkt über den Hub darauf zugreifen, es mit anderen teilen und Ihre Arbeit versionieren. Nach dem Hochladen kann jeder Ihr Modell mit einer einzigen Codezeile laden, indem er einfach auf die eindeutige Modell-ID verweist <username>/my-embedding-gemma

# Push to Hub
model.push_to_hub("my-embedding-gemma")

Zusammenfassung und nächste Schritte

Sie haben jetzt gelernt, wie Sie ein EmbeddingGemma-Modell für eine bestimmte Domain anpassen, indem Sie es mit der Sentence Transformers-Bibliothek feinabstimmen.

Weitere Möglichkeiten mit EmbeddingGemma: