使用 Sentence Transformers 生成嵌入

在 ai.google.dev 上查看 在 Google Colab 中运行 在 Kaggle 中运行 在 Vertex AI 中打开 在 GitHub 上查看源代码

EmbeddingGemma 是一款轻量级开放式嵌入模型,旨在在手机等日常设备上实现快速、高质量的检索。该模型只有 3.08 亿个参数,非常高效,可以直接在本地计算机上运行高级 AI 技术(例如检索增强生成 [RAG]),无需连接互联网。

设置

在开始本教程之前,请完成以下步骤:

  • 登录 Hugging Face 并为某个 Gemma 模型选择确认许可,即可获取对 Gemma 的访问权限。
  • 生成 Hugging Face 访问令牌,并使用该令牌从 Colab 登录。

此笔记本将在 CPU 或 GPU 上运行。

安装 Python 软件包

安装运行 EmbeddingGemma 模型和生成嵌入所需的库。Sentence Transformers 是一个用于文本和图片嵌入的 Python 框架。如需了解详情,请参阅 Sentence Transformers 文档。

pip install -U sentence-transformers git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview

接受许可后,您需要有效的 Hugging Face 令牌才能访问模型。

# Login into Hugging Face Hub
from huggingface_hub import login
login()

加载模型

使用 sentence-transformers 库创建具有 EmbeddingGemma 的模型类的实例。

import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "google/embeddinggemma-300M"
model = SentenceTransformer(model_id).to(device=device)

print(f"Device: {model.device}")
print(model)
print("Total number of parameters in the model:", sum([p.numel() for _, p in model.named_parameters()]))
Device: cuda:0
SentenceTransformer(
  (0): Transformer({'max_seq_length': 2048, 'do_lower_case': False, 'architecture': 'Gemma3TextModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 3072, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Dense({'in_features': 3072, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (4): Normalize()
)
Total number of parameters in the model: 307581696

生成嵌入

嵌入是文本(例如字词或句子)的数值表示形式,用于捕获其语义含义。从本质上讲,它是一个数字列表(一个向量),可让计算机了解字词之间的关系和上下文。

我们来看看 EmbeddingGemma 如何处理三个不同的字词 ["apple", "banana", "car"]

EmbeddingGemma 经过大量文本训练,已学习字词和概念之间的关系。

words = ["apple", "banana", "car"]

# Calculate embeddings by calling model.encode()
embeddings = model.encode(words)

print(embeddings)
for idx, embedding in enumerate(embeddings):
  print(f"Embedding {idx+1} (shape): {embedding.shape}")
[[-0.18476306  0.00167681  0.03773484 ... -0.07996225 -0.02348064
   0.00976741]
 [-0.21189538 -0.02657359  0.02513712 ... -0.08042689 -0.01999852
   0.00512146]
 [-0.18924113 -0.02551468  0.04486253 ... -0.06377774 -0.03699806
   0.03973572]]
Embedding 1: (768,)
Embedding 2: (768,)
Embedding 3: (768,)

模型会为每个句子输出一个数值向量。实际向量非常长 (768),但为简单起见,此处仅显示了几个维度。

关键不在于各个数字本身,而在于向量之间的距离。如果我们要在多维空间中绘制这些向量,applebanana 的向量会非常接近。而 car 的向量会与其他两个向量相距很远。

确定相似度

在本部分中,我们将使用嵌入来确定不同句子在语义上的相似程度。下面展示了高、中、低相似度得分的示例。

  • 高相似度:

    • 句子 A:“厨师为客人准备了一顿美味的饭菜。”
    • 句子 B:“厨师为来访者烹制了一顿美味的晚餐。”
    • 推理:这两个句子使用不同的词语和语法结构(主动语态与被动语态)描述了同一事件。它们传达的核心含义相同。
  • 中等相似度:

    • 句子 A:“她是机器学习方面的专家。”
    • 句子 B:“他对人工智能有着浓厚的兴趣。”
    • 推理:这两个句子相关,因为机器学习是人工智能的一个子领域。不过,它们所指的受众群体不同,参与度也不同(专家级 vs. 感兴趣)。
  • 相似度较低:

    • 句子 A:“东京今天天气晴朗。”
    • 句子 B:“我需要购买本周的杂货。”
    • 推理:这两个句子涉及完全无关的主题,没有语义重叠。
# The sentences to encode
sentence_high = [
    "The chef prepared a delicious meal for the guests.",
    "A tasty dinner was cooked by the chef for the visitors."
]
sentence_medium = [
    "She is an expert in machine learning.",
    "He has a deep interest in artificial intelligence."
]
sentence_low = [
    "The weather in Tokyo is sunny today.",
    "I need to buy groceries for the week."
]

for sentence in [sentence_high, sentence_medium, sentence_low]:
  print("🙋‍♂️")
  print(sentence)
  embeddings = model.encode(sentence)
  similarities = model.similarity(embeddings[0], embeddings[1])
  print("`-> 🤖 score: ", similarities.numpy()[0][0])
🙋‍♂️
['The chef prepared a delicious meal for the guests.', 'A tasty dinner was cooked by the chef for the visitors.']
`-> 🤖 score:  0.8002148
🙋‍♂️
['She is an expert in machine learning.', 'He has a deep interest in artificial intelligence.']
`-> 🤖 score:  0.45417833
🙋‍♂️
['The weather in Tokyo is sunny today.', 'I need to buy groceries for the week.']
`-> 🤖 score:  0.22262995

将提示与 EmbeddingGemma 搭配使用

如需使用 EmbeddingGemma 生成最佳嵌入,您应在输入文本的开头添加“指令提示”或“任务”。这些提示可针对特定任务(例如文档检索或问答)优化嵌入,并帮助模型区分不同的输入类型,例如搜索查询与文档。

如何应用提示

您可以通过三种方式在推理期间应用提示。

  1. 使用 prompt 实参
    将完整的提示字符串直接传递给 encode 方法。这样一来,您就可以精确控制。

    embeddings = model.encode(
        sentence,
        prompt="task: sentence similarity | query: "
    )
    
  2. 使用 prompt_name 实参
    按名称选择预定义提示。这些提示是从模型的配置中加载的,或者是在模型初始化期间加载的。

    embeddings = model.encode(sentence, prompt_name="STS")
    
  3. 使用默认提示
    如果您未指定 promptprompt_name,系统会自动使用设置为 default_prompt_name 的提示;如果未设置默认提示,则不会应用任何提示。

    embeddings = model.encode(sentence)
    
print("Available tasks:")
for name, prefix in model.prompts.items():
  print(f" {name}: \"{prefix}\"")
print("-"*80)

for sentence in [sentence_high, sentence_medium, sentence_low]:
  print("🙋‍♂️")
  print(sentence)
  embeddings = model.encode(sentence, prompt_name="STS")
  similarities = model.similarity(embeddings[0], embeddings[1])
  print("`-> 🤖 score: ", similarities.numpy()[0][0])
Available tasks:
 query: "task: search result | query: "
 document: "title: none | text: "
 BitextMining: "task: search result | query: "
 Clustering: "task: clustering | query: "
 Classification: "task: classification | query: "
 InstructionRetrieval: "task: code retrieval | query: "
 MultilabelClassification: "task: classification | query: "
 PairClassification: "task: sentence similarity | query: "
 Reranking: "task: search result | query: "
 Retrieval: "task: search result | query: "
 Retrieval-query: "task: search result | query: "
 Retrieval-document: "title: none | text: "
 STS: "task: sentence similarity | query: "
 Summarization: "task: summarization | query: "
--------------------------------------------------------------------------------
🙋‍♂️
['The chef prepared a delicious meal for the guests.', 'A tasty dinner was cooked by the chef for the visitors.']
`-> 🤖 score:  0.9363755
🙋‍♂️
['She is an expert in machine learning.', 'He has a deep interest in artificial intelligence.']
`-> 🤖 score:  0.6425841
🙋‍♂️
['The weather in Tokyo is sunny today.', 'I need to buy groceries for the week.']
`-> 🤖 score:  0.38587403

使用场景:检索增强生成 (RAG)

对于 RAG 系统,请使用以下 prompt_name 值为查询和文档创建专用嵌入:

  • 对于查询:请使用 prompt_name="Retrieval-query"

    query_embedding = model.encode(
        "How do I use prompts with this model?",
        prompt_name="Retrieval-query"
    )
    
  • 对于文档:请使用 prompt_name="Retrieval-document"。为了进一步改进文档嵌入,您还可以使用 prompt 实参直接添加标题:

    • 带有标题
    doc_embedding = model.encode(
        "The document text...",
        prompt="title: Using Prompts in RAG | text: "
    )
    
    • 没有标题
    doc_embedding = model.encode(
        "The document text...",
        prompt="title: none | text: "
    )
    

延伸阅读

分类

分类任务是指将一段文本分配给一个或多个预定义的类别或标签。这是自然语言处理 (NLP) 中最基本的任务之一。

文本分类的一个实际应用是客户支持服务工单路由。此流程会自动将客户查询转到正确的部门,从而节省时间并减少人工工作量。

labels = ["Billing Issue", "Technical Support", "Sales Inquiry"]

sentence = [
  "Excuse me, the app freezes on the login screen. It won't work even when I try to reset my password.",
  "I would like to inquire about your enterprise plan pricing and features for a team of 50 people.",
]

# Calculate embeddings by calling model.encode()
label_embeddings = model.encode(labels, prompt_name="Classification")
embeddings = model.encode(sentence, prompt_name="Classification")

# Calculate the embedding similarities
similarities = model.similarity(embeddings, label_embeddings)
print(similarities)

idx = similarities.argmax(1)
print(idx)

for example in sentence:
  print("🙋‍♂️", example, "-> 🤖", labels[idx[sentence.index(example)]])
tensor([[0.4673, 0.5145, 0.3604],
        [0.4191, 0.5010, 0.5966]])
tensor([1, 2])
🙋‍♂️ Excuse me, the app freezes on the login screen. It won't work even when I try to reset my password. -> 🤖 Technical Support
🙋‍♂️ I would like to inquire about your enterprise plan pricing and features for a team of 50 people. -> 🤖 Sales Inquiry

Matryoshka 表征学习 (MRL)

EmbeddingGemma 利用 MRL 通过一个模型提供多种嵌入大小。这是一种巧妙的训练方法,可创建单个高质量嵌入,其中最重要的信息集中在向量的开头。

这意味着,您只需取完整嵌入的前 N 个维度,即可获得一个较小但仍非常实用的嵌入。使用较小且截断的嵌入向量可显著降低存储成本并加快处理速度,但这种效率是以可能降低嵌入向量质量为代价的。借助 MRL,您可以根据应用的具体需求,在速度和准确率之间选择最佳平衡点。

我们使用三个字词 ["apple", "banana", "car"] 并创建简化的嵌入,以了解 MRL 的运作方式。

def check_word_similarities():
  # Calculate the embedding similarities
  print("similarity function: ", model.similarity_fn_name)
  similarities = model.similarity(embeddings[0], embeddings[1:])
  print(similarities)

  for idx, word in enumerate(words[1:]):
    print("🙋‍♂️ apple vs.", word, "-> 🤖 score: ", similarities.numpy()[0][idx])

# Calculate embeddings by calling model.encode()
embeddings = model.encode(words, prompt_name="STS")

check_word_similarities()
similarity function:  cosine
tensor([[0.7510, 0.6685]])
🙋‍♂️ apple vs. banana -> 🤖 score:  0.75102395
🙋‍♂️ apple vs. car -> 🤖 score:  0.6684626

现在,您无需使用新模型即可更快地运行应用。只需将完整嵌入截断为前 512 个维度即可。为获得最佳效果,还建议设置 normalize_embeddings=True,将向量缩放到单位长度 1。

embeddings = model.encode(words, truncate_dim=512, normalize_embeddings=True)

for idx, embedding in enumerate(embeddings):
  print(f"Embedding {idx+1}: {embedding.shape}")

print("-"*80)
check_word_similarities()
Embedding 1: (512,)
Embedding 2: (512,)
Embedding 3: (512,)
--------------------------------------------------------------------------------
similarity function:  cosine
tensor([[0.7674, 0.7041]])
🙋‍♂️ apple vs. banana -> 🤖 score:  0.767427
🙋‍♂️ apple vs. car -> 🤖 score:  0.7040509

在资源极其受限的环境中,您可以进一步缩短嵌入向量,使其仅包含 256 个维度。您还可以使用更高效的点积进行相似度计算,而不是使用标准的余弦相似度。

model = SentenceTransformer(model_id, truncate_dim=256, similarity_fn_name="dot").to(device=device)
embeddings = model.encode(words, prompt_name="STS", normalize_embeddings=True)

for idx, embedding in enumerate(embeddings):
  print(f"Embedding {idx+1}: {embedding.shape}")

print("-"*80)
check_word_similarities()
Embedding 1: (256,)
Embedding 2: (256,)
Embedding 3: (256,)
--------------------------------------------------------------------------------
similarity function:  dot
tensor([[0.7855, 0.7382]])
🙋‍♂️ apple vs. banana -> 🤖 score:  0.7854644
🙋‍♂️ apple vs. car -> 🤖 score:  0.7382126

总结与后续步骤

现在,您可以使用 EmbeddingGemma 和 Sentence Transformers 库生成高质量的文本嵌入了。运用这些技能来构建强大的功能,例如语义相似度、文本分类和检索增强生成 (RAG) 系统,并继续探索 Gemma 模型可实现的功能。

接下来,请参阅以下文档: