Fine-tuning helps close the gap between a model's general-purpose understanding and the specialized, high-accuracy performance your application requires. Since no single model is perfect for every task, fine-tuning adapts it to your specific domain.
Imagine your company, "Shibuya Financial", offers a range of complex financial products: investment trusts, NISA accounts (tax-advantaged savings accounts), and home loans. Your customer support team uses an internal knowledge base to quickly find answers to customer questions.
Setup
Before starting this tutorial, complete the following steps:
- Get access to EmbeddingGemma by logging into Hugging Face and selecting Acknowledge license for a Gemma model.
- Generate a Hugging Face access token and use it to log in from Colab.
This notebook will run on either a CPU or a GPU.
Install Python packages
Install the libraries required for running the EmbeddingGemma model and generating embeddings. Sentence Transformers is a Python framework for text and image embeddings. For more information, see the Sentence Transformers documentation.
pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
After you have accepted the license, you need a valid Hugging Face token to access the model.
# Login into Hugging Face Hub
from huggingface_hub import login
login()
Load the model
Use the sentence-transformers library to create an instance of a model class with EmbeddingGemma.
import torch
from sentence_transformers import SentenceTransformer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)
print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696
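The printout shows the model's module stack, but not the named task prompts it ships with, which the code below relies on via task_name. As a quick check (a sketch, not part of the original notebook; it assumes the checkpoint exposes its prompts through model.prompts, as Sentence Transformers models configured with prompts do), you can list them:
# List the named task prompts bundled with the checkpoint. Each key
# (e.g. "STS") maps to a prefix string that encode() prepends to the input.
for name, prompt in model.prompts.items():
    print(f"{name!r}: {prompt!r}")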
Prepare the fine-tuning dataset
This is the most critical part. You need to create a dataset that teaches the model what "similar" means in your specific context. This data is often structured as triplets: (anchor, positive, negative)
- Anchor: The original query or sentence.
- Positive: A sentence that is semantically very similar or identical to the anchor.
- Negative: A sentence that is on a related topic but semantically distinct.
In this example, we prepared only 3 triplets, but in a real application you would need a much larger dataset to get good results (see the sketch after the output below for loading one from a file).
from datasets import Dataset
dataset = [
    ["How do I open a NISA account?", "What is the procedure for starting a new tax-free investment account?", "I want to check the balance of my regular savings account."],
    ["Are there fees for making an early repayment on a home loan?", "If I pay back my house loan early, will there be any costs?", "What is the management fee for this investment trust?"],
    ["What is the coverage for medical insurance?", "Tell me about the benefits of the health insurance plan.", "What is the cancellation policy for my life insurance?"],
]
# Convert the list-based dataset into a list of dictionaries.
data_as_dicts = [ {"anchor": row[0], "positive": row[1], "negative": row[2]} for row in dataset ]
# Create a Hugging Face `Dataset` object from the list of dictionaries.
train_dataset = Dataset.from_list(data_as_dicts)
print(train_dataset)
Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 3
})
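In practice you would load the triplets from a file instead of hard-coding them. A minimal sketch, assuming a hypothetical triplets.csv with anchor, positive, and negative columns (the file name and layout are illustrative, not part of this tutorial):
from datasets import load_dataset

# Hypothetical CSV with "anchor", "positive", and "negative" columns;
# the resulting Dataset plugs into the trainer below unchanged.
train_dataset = load_dataset("csv", data_files="triplets.csv", split="train")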
Before fine-tuning
A search for "tax-free investment" might have returned the following results, with similarity scores:
- Document: Opening a NISA Account (Score: 0.45)
- Document: Opening a Regular Savings Account (Score: 0.48) <- Close score, potentially confusing
- Document: Home Loan Application Guide (Score: 0.42)
task_name = "STS"
def get_scores(query, documents):
# Calculate embeddings by calling model.encode()
query_embeddings = model.encode(query, prompt=task_name)
doc_embeddings = model.encode(documents, prompt=task_name)
# Calculate the embedding similarities
similarities = model.similarity(query_embeddings, doc_embeddings)
for idx, doc in enumerate(documents):
print("Document: ", doc, "-> 🤖 Score: ", similarities.numpy()[0][idx])
query = "I want to start a tax-free installment investment, what should I do?"
documents = ["Opening a NISA Account", "Opening a Regular Savings Account", "Home Loan Application Guide"]
get_scores(query, documents)
Document: Opening a NISA Account -> 🤖 Score: 0.45698774
Document: Opening a Regular Savings Account -> 🤖 Score: 0.48092696
Document: Home Loan Application Guide -> 🤖 Score: 0.42127067
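These scores come from model.similarity, which defaults to cosine similarity for this model. Because the pipeline ends in a Normalize() module, the embeddings are unit-length, so cosine similarity reduces to a plain dot product. A short sanity check (a sketch, not part of the original notebook):
import numpy as np

# With L2-normalized embeddings, the dot product equals cosine similarity.
q = model.encode(query, prompt_name=task_name)
d = model.encode(documents, prompt_name=task_name)
print(np.allclose(q @ d.T, model.similarity(q, d).numpy(), atol=1e-6))  # True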
Training
With a framework like sentence-transformers in Python, the base model gradually learns the subtle distinctions in your financial vocabulary.
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="my-embedding-gemma",
    # Optional training parameters:
    prompts=model.prompts[task_name],  # use model's prompt to train
    num_train_epochs=5,
    per_device_train_batch_size=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Optional tracking/debugging parameters:
    logging_steps=train_dataset.num_rows,
    report_to="none",
)
class MyCallback(TrainerCallback):
    "A callback that evaluates the model at the end of each epoch"
    def __init__(self, evaluate):
        self.evaluate = evaluate  # evaluate function
    def on_log(self, args, state, control, **kwargs):
        # Run the evaluation at each logging step (every epoch here)
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

def evaluate():
    get_scores(query, documents)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    callbacks=[MyCallback(evaluate)],
)
trainer.train()
Step 3 finished. Running evaluation:
Document: Opening a NISA Account -> 🤖 Score: 0.6449194
Document: Opening a Regular Savings Account -> 🤖 Score: 0.44123
Document: Home Loan Application Guide -> 🤖 Score: 0.46752414
Step 6 finished. Running evaluation:
Document: Opening a NISA Account -> 🤖 Score: 0.68873787
Document: Opening a Regular Savings Account -> 🤖 Score: 0.34069622
Document: Home Loan Application Guide -> 🤖 Score: 0.50065553
Step 9 finished. Running evaluation:
Document: Opening a NISA Account -> 🤖 Score: 0.7148906
Document: Opening a Regular Savings Account -> 🤖 Score: 0.30480516
Document: Home Loan Application Guide -> 🤖 Score: 0.52454984
Step 12 finished. Running evaluation:
Document: Opening a NISA Account -> 🤖 Score: 0.72614634
Document: Opening a Regular Savings Account -> 🤖 Score: 0.29255486
Document: Home Loan Application Guide -> 🤖 Score: 0.5370023
Step 15 finished. Running evaluation:
Document: Opening a NISA Account -> 🤖 Score: 0.7294032
Document: Opening a Regular Savings Account -> 🤖 Score: 0.2893038
Document: Home Loan Application Guide -> 🤖 Score: 0.54087913
Step 15 finished. Running evaluation:
Document: Opening a NISA Account -> 🤖 Score: 0.7294032
Document: Opening a Regular Savings Account -> 🤖 Score: 0.2893038
Document: Home Loan Application Guide -> 🤖 Score: 0.54087913
TrainOutput(global_step=15, training_loss=0.009651281436261646, metrics={'train_runtime': 63.2486, 'train_samples_per_second': 0.237, 'train_steps_per_second': 0.237, 'total_flos': 0.0, 'train_loss': 0.009651281436261646, 'epoch': 5.0})
After fine-tuning
Now, the same search yields much clearer results:
- Document: Opening a NISA Account (Score: 0.72) <- Much more confident
- Document: Opening a Regular Savings Account (Score: 0.28) <- Clearly less relevant
- Document: Home Loan Application Guide (Score: 0.54)
get_scores(query, documents)
Document: Opening a NISA Account -> 🤖 Score: 0.7294032
Document: Opening a Regular Savings Account -> 🤖 Score: 0.2893038
Document: Home Loan Application Guide -> 🤖 Score: 0.54087913
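For an actual retrieval workload you would typically encode queries and documents asymmetrically, each with its own prompt, rather than with the symmetric STS prompt used in this comparison. A hedged sketch, assuming the checkpoint defines "query" and "document" prompt names as the EmbeddingGemma model card describes:
# Encode the query and the documents with their dedicated retrieval prompts.
q_emb = model.encode(query, prompt_name="query")
d_emb = model.encode(documents, prompt_name="document")
print(model.similarity(q_emb, d_emb))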
To upload your model to the Hugging Face Hub, you can use the push_to_hub method from the Sentence Transformers library.
Uploading your model makes it easy to access for inference directly from the Hub, share with others, and version your work. Once uploaded, anyone can load your model with a single line of code, simply by referencing its unique model ID <username>/my-embedding-gemma.
# Push to Hub
model.push_to_hub("my-embedding-gemma")
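Once the upload completes, loading the model back is that single line; replace <username> with your Hugging Face username (the repository name matches the one pushed above):
from sentence_transformers import SentenceTransformer

# Load the fine-tuned model directly from the Hub for inference.
tuned_model = SentenceTransformer("<username>/my-embedding-gemma")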
Summary and next steps
You have now learned how to adapt an EmbeddingGemma model for a specific domain by fine-tuning it with the Sentence Transformers library.
Explore what more you can do with EmbeddingGemma:
- Training Overview in the Sentence Transformers documentation
- Generate embeddings with Sentence Transformers
- Simple RAG example in the Gemma Cookbook