Dostrajanie modelu EmbeddingGemma

Wyświetl na ai.google.dev Uruchom w Google Colab Uruchom w Kaggle Otwórz w Vertex AI Wyświetl źródło w GitHubie

Dostrajanie pomaga zmniejszyć różnicę między ogólnym zrozumieniem modelu a wysoką dokładnością, której wymaga Twoja aplikacja. Żaden model nie jest idealny do każdego zadania, dlatego dostrajanie pozwala dostosować go do konkretnej domeny.

Wyobraź sobie, że Twoja firma „Shibuya Financial” oferuje różne złożone produkty finansowe, takie jak fundusze inwestycyjne, konta NISA (konto oszczędnościowe z korzyściami podatkowymi) i kredyty hipoteczne. Twój zespół obsługi klienta korzysta z wewnętrznej bazy wiedzy, aby szybko znajdować odpowiedzi na pytania klientów.

Konfiguracja

Zanim rozpoczniesz wykonywanie zadań z tego samouczka, wykonaj te czynności:

  • Aby uzyskać dostęp do EmbeddingGemma, zaloguj się w Hugging Face i wybierz Potwierdź licencję dla modelu Gemma.
  • Wygeneruj token dostępu Hugging Face i użyj go, aby zalogować się w Colab.

Ten notatnik będzie działać na procesorze lub GPU.

Instalowanie pakietów Pythona

Zainstaluj biblioteki wymagane do uruchomienia modelu EmbeddingGemma i generowania wektorów. Sentence Transformers to platforma Pythona do tworzenia reprezentacji właściwościowych tekstu i obrazów. Więcej informacji znajdziesz w dokumentacji Sentence Transformers.

pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

Po zaakceptowaniu licencji musisz mieć ważny token Hugging Face, aby uzyskać dostęp do modelu.

# Login into Hugging Face Hub
from huggingface_hub import login
login()

Wczytaj model

Użyj bibliotek sentence-transformers, aby utworzyć instancję klasy modelu z EmbeddingGemma.

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696

Przygotowywanie zbioru danych do dostrajania

To najważniejsza część. Musisz utworzyć zbiór danych, który nauczy model, co w Twoim konkretnym kontekście oznacza „podobny”. Dane są często uporządkowane w formie trójek: (kotwica, pozytyw, negatyw).

  • Anchor: oryginalne zapytanie lub zdanie.
  • Pozytywny: zdanie, które jest semantycznie bardzo podobne lub identyczne z tekstem zakotwiczenia.
  • Negatywny: zdanie, które dotyczy powiązanego tematu, ale jest semantycznie odrębne.

W tym przykładzie przygotowaliśmy tylko 3 triplety, ale w przypadku prawdziwej aplikacji potrzebny byłby znacznie większy zbiór danych, aby uzyskać dobre wyniki.

from datasets import Dataset

dataset = [
    ["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
    ["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
    ["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
]

# Convert the list-based dataset into a list of dictionaries.
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]

# Create a Hugging Face `Dataset` object from the list of dictionaries.
train_dataset = Dataset.from_list(data_as_dicts)
print(train_dataset)
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 3
})

Przed dostrajaniem

Wyszukiwanie hasła „inwestycja zwolniona z podatku” mogło dać te wyniki z ocenami podobieństwa:

  1. Dokument: Opening a NISA account (Wynik: 0,45)
  2. Dokument: Otwieranie zwykłego konta oszczędnościowego (wynik: 0,48) <- Podobny wynik, może być mylący
  3. Dokument: Home Loan Application Guide (wynik: 0,42)
task_name = "STS"

def get_scores(query, documents):
  # Calculate embeddings by calling model.encode()
  query_embeddings = model.encode(query, prompt=task_name)
  doc_embeddings = model.encode(documents, prompt=task_name)

  # Calculate the embedding similarities
  similarities = model.similarity(query_embeddings, doc_embeddings)

  for idx, doc in enumerate(documents):
    print("Document: ", doc, "-> 🤖 Score: ", similarities.numpy()[0][idx])

query = "I want to start a tax-free installment investment, what should I do?"
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]

get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.45698774
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.48092696
Document:  Home Loan Application Guide -> 🤖 Score:  0.42127067

Szkolenia

Korzystając z platformy takiej jak sentence-transformers w Pythonie, model podstawowy stopniowo uczy się subtelnych różnic w Twoim słownictwie finansowym.

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback

loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="my-embedding-gemma",
    # Optional training parameters:
    prompts=model.prompts[task_name],    # use model's prompt to train
    num_train_epochs=5,
    per_device_train_batch_size=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Optional tracking/debugging parameters:
    logging_steps=train_dataset.num_rows,
    report_to="none",
)

class MyCallback(TrainerCallback):
    "A callback that evaluates the model at the end of eopch"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Evaluate the model using text generation
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

def evaluate():
  get_scores(query, documents)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[MyCallback(evaluate)]
)
trainer.train()
Step 3 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.6449194
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.44123
Document:  Home Loan Application Guide -> 🤖 Score:  0.46752414
Step 6 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.68873787
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.34069622
Document:  Home Loan Application Guide -> 🤖 Score:  0.50065553
Step 9 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7148906
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.30480516
Document:  Home Loan Application Guide -> 🤖 Score:  0.52454984
Step 12 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.72614634
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.29255486
Document:  Home Loan Application Guide -> 🤖 Score:  0.5370023
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
TrainOutput(global_step=15, training_loss=0.009651281436261646, metrics={'train_runtime': 63.2486, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.237, 'total_flos': 0.0, 'train_loss': 0.009651281436261646, 'epoch': 5.0})

Po dostrajaniu

To samo wyszukiwanie daje teraz znacznie bardziej przejrzyste wyniki:

  1. Dokument: Opening a NISA account (Score: 0.72) <- Znacznie większa pewność
  2. Dokument: Opening a Regular Saving Account (wynik: 0,28) <- Wyraźnie mniej trafny
  3. Dokument: Przewodnik po wnioskowaniu o kredyt hipoteczny (wynik: 0,54)
get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913

Aby przesłać model do centrum Hugging Face, możesz użyć metody push_to_hub z biblioteki Sentence Transformers.

Przesłanie modelu ułatwia dostęp do niego na potrzeby wnioskowania bezpośrednio z platformy, udostępnianie go innym i wersjonowanie pracy. Po przesłaniu każdy może załadować Twój model za pomocą jednego wiersza kodu, po prostu odwołując się do jego unikalnego identyfikatora modelu <username>/my-embedding-gemma

# Push to Hub
model.push_to_hub("my-embedding-gemma")

Podsumowanie i dalsze kroki

Wiesz już, jak dostosować model EmbeddingGemma do konkretnej domeny, precyzyjnie go dostrajając za pomocą biblioteki Sentence Transformers.

Sprawdź, co jeszcze możesz zrobić dzięki EmbeddingGemma: