Ajusta EmbeddingGemma

Ver en ai.google.dev Ejecutar en Google Colab Ejecutar en Kaggle Abrir en Vertex AI Ver código fuente en GitHub

El ajuste fino ayuda a cerrar la brecha entre la comprensión de propósito general de un modelo y la precisión especializada y de alto rendimiento que requiere tu aplicación. Dado que ningún modelo es perfecto para todas las tareas, el ajuste lo adapta a tu dominio específico.

Imagina que tu empresa, "Shibuya Financial", ofrece varios productos financieros complejos, como fideicomisos de inversión, cuentas NISA (una cuenta de ahorro con ventajas fiscales) y préstamos para la vivienda. Tu equipo de asistencia al cliente usa una base de conocimiento interna para encontrar rápidamente respuestas a las preguntas de los clientes.

Configuración

Antes de comenzar este instructivo, completa los siguientes pasos:

  • Para acceder a EmbeddingGemma, ingresa a Hugging Face y selecciona Acknowledge license para un modelo de Gemma.
  • Genera un token de acceso de Hugging Face y úsalo para acceder desde Colab.

Este notebook se ejecutará en la CPU o la GPU.

Instala paquetes de Python

Instala las bibliotecas necesarias para ejecutar el modelo EmbeddingGemma y generar embeddings. Sentence Transformers es un framework de Python para incorporaciones de texto e imágenes. Para obtener más información, consulta la documentación de Sentence Transformers.

pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

Después de aceptar la licencia, necesitarás un token de Hugging Face válido para acceder al modelo.

# Login into Hugging Face Hub
from huggingface_hub import login
login()

Cargar modelo

Usa las bibliotecas de sentence-transformers para crear una instancia de una clase de modelo con EmbeddingGemma.

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696

Prepara el conjunto de datos de ajuste

Esta es la parte más importante. Debes crear un conjunto de datos que le enseñe al modelo qué significa "similar" en tu contexto específico. Estos datos suelen estructurarse como tríos: (ancla, positivo, negativo).

  • Ancla: Es la oración o búsqueda original.
  • Positivo: Es una oración que es semánticamente muy similar o idéntica a la ancla.
  • Negativa: Es una oración que trata sobre un tema relacionado, pero que es semánticamente distinta.

En este ejemplo, solo preparamos 3 tríos, pero, para una aplicación real, necesitarías un conjunto de datos mucho más grande para que funcione bien.

from datasets import Dataset

dataset = [
    ["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
    ["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
    ["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
]

# Convert the list-based dataset into a list of dictionaries.
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]

# Create a Hugging Face `Dataset` object from the list of dictionaries.
train_dataset = Dataset.from_list(data_as_dicts)
print(train_dataset)
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 3
})

Antes del ajuste

Una búsqueda de "inversión libre de impuestos" podría haber arrojado los siguientes resultados, con sus respectivos índices de similitud:

  1. Documento: Apertura de una cuenta NISA (puntuación: 0.45)
  2. Documento: Opening a Regular Saving Account (Puntuación: 0.48) <- Puntuación similar, potencialmente confusa
  3. Documento: Guía de solicitud de préstamo para vivienda (puntuación: 0.42)
task_name = "STS"

def get_scores(query, documents):
  # Calculate embeddings by calling model.encode()
  query_embeddings = model.encode(query, prompt=task_name)
  doc_embeddings = model.encode(documents, prompt=task_name)

  # Calculate the embedding similarities
  similarities = model.similarity(query_embeddings, doc_embeddings)

  for idx, doc in enumerate(documents):
    print("Document: ", doc, "-> 🤖 Score: ", similarities.numpy()[0][idx])

query = "I want to start a tax-free installment investment, what should I do?"
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]

get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.45698774
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.48092696
Document:  Home Loan Application Guide -> 🤖 Score:  0.42127067

Capacitación

Con un framework como sentence-transformers en Python, el modelo básico aprende gradualmente las distinciones sutiles en tu vocabulario financiero.

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback

loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="my-embedding-gemma",
    # Optional training parameters:
    prompts=model.prompts[task_name],    # use model's prompt to train
    num_train_epochs=5,
    per_device_train_batch_size=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Optional tracking/debugging parameters:
    logging_steps=train_dataset.num_rows,
    report_to="none",
)

class MyCallback(TrainerCallback):
    "A callback that evaluates the model at the end of eopch"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Evaluate the model using text generation
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

def evaluate():
  get_scores(query, documents)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[MyCallback(evaluate)]
)
trainer.train()
Step 3 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.6449194
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.44123
Document:  Home Loan Application Guide -> 🤖 Score:  0.46752414
Step 6 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.68873787
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.34069622
Document:  Home Loan Application Guide -> 🤖 Score:  0.50065553
Step 9 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7148906
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.30480516
Document:  Home Loan Application Guide -> 🤖 Score:  0.52454984
Step 12 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.72614634
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.29255486
Document:  Home Loan Application Guide -> 🤖 Score:  0.5370023
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
TrainOutput(global_step=15, training_loss=0.009651281436261646, metrics={'train_runtime': 63.2486, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.237, 'total_flos': 0.0, 'train_loss': 0.009651281436261646, 'epoch': 5.0})

Después del ajuste

La misma búsqueda ahora arroja resultados mucho más claros:

  1. Documento: Opening a NISA account (Apertura de una cuenta NISA; Puntuación: 0.72) <- Mucho más seguro
  2. Documento: Opening a Regular Saving Account (Puntuación: 0.28) <- Claramente menos relevante
  3. Documento: Guía de solicitud de préstamo hipotecario (puntuación: 0.54)
get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913

Para subir tu modelo a Hugging Face Hub, puedes usar el método push_to_hub de la biblioteca de Sentence Transformers.

Subir tu modelo facilita el acceso a la inferencia directamente desde el Hub, compartirlo con otras personas y controlar las versiones de tu trabajo. Una vez que se sube, cualquier persona puede cargar tu modelo con una sola línea de código, simplemente haciendo referencia a su ID de modelo único <username>/my-embedding-gemma.

# Push to Hub
model.push_to_hub("my-embedding-gemma")

Resumen y próximos pasos

Ahora aprendiste a adaptar un modelo de EmbeddingGemma para un dominio específico ajustándolo con la biblioteca de Sentence Transformers.

Explora qué más puedes hacer con EmbeddingGemma: