Fine-tune EmbeddingGemma

Fine-tuning helps close the gap between a model's general-purpose understanding and the specialized, high-accuracy performance your application needs. Since no single model is perfect for every task, fine-tuning adapts the model to your specific domain.

Imagine that your company, "Shibuya Financial", offers a range of complex financial products: investment trusts, NISA accounts (a tax-advantaged savings account), and home loans. Your customer support team uses an internal knowledge base to quickly find answers to customer questions.

Setup

Before starting this tutorial, complete the following steps:

  • Get access to EmbeddingGemma by logging in to Hugging Face and selecting Acknowledge license for a Gemma model.
  • Generate a Hugging Face access token and use it to log in from Colab.

This notebook runs on either CPU or GPU.

Install Python packages

Install the libraries required to run the EmbeddingGemma model and generate embeddings. Sentence Transformers is a Python framework for text and image embeddings. For more information, see the Sentence Transformers documentation.

pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

After accepting the license, you need a valid Hugging Face token to access the model.

# Log in to the Hugging Face Hub
from huggingface_hub import login
login()

Load the model

Use the sentence-transformers library to create an instance of a model class with EmbeddingGemma.

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696
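
As a quick sanity check (our addition, not part of the original tutorial), you can encode a single sentence and confirm that the output is one 768-dimensional vector, matching the Normalize-capped architecture printed above:

# Encode one sentence and inspect the resulting embedding.
embedding = model.encode("The weather is lovely today.")
print(embedding.shape)  # Expected: (768,)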

Prepare the fine-tuning dataset

This is the most critical part. You need to create a dataset that teaches the model what "similar" means in your specific context. This data is often structured as triplets: (anchor, positive, negative)

  • Anchor: The original query or sentence.
  • Positive: A sentence that is semantically very similar or identical to the anchor.
  • Negative: A sentence that is about a related topic but is semantically distinct.

In this example, we prepare only 3 triplets, but in a real application you would need a much larger dataset to get good results (a sketch of loading triplets from a file follows the code below).

from datasets import Dataset

dataset = [
    ["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
    ["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
    ["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
]

# Convert the list-based dataset into a list of dictionaries.
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]

# Create a Hugging Face `Dataset` object from the list of dictionaries.
train_dataset = Dataset.from_list(data_as_dicts)
print(train_dataset)
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 3
})
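
For a larger dataset, rather than hard-coding triplets, you would typically keep them in a file and load them with the datasets library. A minimal sketch, assuming a hypothetical triplets.jsonl file with one JSON object per line containing "anchor", "positive", and "negative" keys:

from datasets import load_dataset

# Each line of triplets.jsonl is assumed to be a JSON object with
# "anchor", "positive", and "negative" fields.
train_dataset = load_dataset("json", data_files="triplets.jsonl", split="train")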

Before fine-tuning

A search for "tax-free investment" might return the following results, with similarity scores:

  1. Document: Opening a NISA Account (Score: 0.45)
  2. Document: Opening a Regular Savings Account (Score: 0.48) <- Similar score, potentially confusing
  3. Document: Home Loan Application Guide (Score: 0.42)

task_name = "STS"

def get_scores(query, documents):
  # Calculate embeddings by calling model.encode()
  # prompt_name looks up the model's built-in prompt for the task
  query_embeddings = model.encode(query, prompt_name=task_name)
  doc_embeddings = model.encode(documents, prompt_name=task_name)

  # Calculate the embedding similarities
  similarities = model.similarity(query_embeddings, doc_embeddings)

  for idx, doc in enumerate(documents):
    print("Document: ", doc, "-> 🤖 Score: ", similarities.numpy()[0][idx])

query = "I want to start a tax-free installment investment, what should I do?"
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]

get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.45698774
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.48092696
Document:  Home Loan Application Guide -> 🤖 Score:  0.42127067
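
To turn these raw scores into a ranked list like the one shown above, you can sort the documents by similarity. A minimal sketch; the rank_documents helper is our own name, not part of the tutorial:

def rank_documents(query, documents):
  # Embed the query and documents with the model's built-in task prompt.
  query_embeddings = model.encode(query, prompt_name=task_name)
  doc_embeddings = model.encode(documents, prompt_name=task_name)
  scores = model.similarity(query_embeddings, doc_embeddings).numpy()[0]
  # Pair each document with its score and sort, highest first.
  return sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

for doc, score in rank_documents(query, documents):
  print(f"Score: {score:.2f} -> {doc}")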

Training

With a framework like sentence-transformers in Python, the base model gradually learns the subtle distinctions in your financial vocabulary.

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback

loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="my-embedding-gemma",
    # Optional training parameters:
    prompts=model.prompts[task_name],    # use model's prompt to train
    num_train_epochs=5,
    per_device_train_batch_size=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Optional tracking/debugging parameters:
    logging_steps=train_dataset.num_rows,
    report_to="none",
)

class MyCallback(TrainerCallback):
    "A callback that evaluates the model at the end of each epoch"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Score the sample query against the documents at each logging step
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

def evaluate():
  get_scores(query, documents)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[MyCallback(evaluate)]
)
trainer.train()
Step 3 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.6449194
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.44123
Document:  Home Loan Application Guide -> 🤖 Score:  0.46752414
Step 6 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.68873787
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.34069622
Document:  Home Loan Application Guide -> 🤖 Score:  0.50065553
Step 9 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7148906
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.30480516
Document:  Home Loan Application Guide -> 🤖 Score:  0.52454984
Step 12 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.72614634
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.29255486
Document:  Home Loan Application Guide -> 🤖 Score:  0.5370023
Step 15 finished. Running evaluation:
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913
TrainOutput(global_step=15, training_loss=0.009651281436261646, metrics={'train_runtime': 63.2486, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.237, 'total_flos': 0.0, 'train_loss': 0.009651281436261646, 'epoch': 5.0})

After fine-tuning

Now the same search yields much clearer results:

  1. Document: Opening a NISA Account (Score: 0.72) <- Much more confident
  2. Document: Opening a Regular Savings Account (Score: 0.28) <- Clearly less relevant
  3. Document: Home Loan Application Guide (Score: 0.54)

get_scores(query, documents)
Document:  Opening a NISA Account -> 🤖 Score:  0.7294032
Document:  Opening a Regular Savings Account -> 🤖 Score:  0.2893038
Document:  Home Loan Application Guide -> 🤖 Score:  0.54087913

To upload the model to the Hugging Face Hub, you can use the push_to_hub method from the Sentence Transformers library.

Uploading the model makes it easy to access it straight from the Hub for inference, share it with others, and version your work. Once uploaded, anyone can load your model with a single line of code by referencing its unique model ID <username>/my-embedding-gemma, as shown after the upload code below.

# Push to Hub
model.push_to_hub("my-embedding-gemma")
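
After the upload completes, loading the fine-tuned model back is the same one-liner used to load the base model earlier in this tutorial; replace <username> with your own Hugging Face username:

from sentence_transformers import SentenceTransformer

# Load the fine-tuned model directly from the Hub by its model ID.
tuned_model = SentenceTransformer("<username>/my-embedding-gemma")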

Summary and next steps

You have now learned how to adapt an EmbeddingGemma model to a specific domain by fine-tuning it with the Sentence Transformers library.

Explore what more you can do with EmbeddingGemma: