这段代码的核心是将文本转换为标准化的语义嵌入向量,专为知识图谱构建设计

1. Sentence Transformer的工作原理

Sentence_Transformer完全按照Sentence Transformer的工作原理设计:

  • 使用AutoModel.from_pretrained加载BERT模型
  • mean_pooling方法将token embeddings平均池化为句子嵌入
  • F.normalize进行L2归一化,使向量长度为1

这与传统的Word2Vec不同,Sentence Transformer能捕捉更丰富的语义信息。

2. 文本处理流程

sbert_text2embedding函数的处理流程:

  1. 使用tokenizer将文本转换为token ids和attention mask(注意力掩码)
  2. 将数据组织成PyTorch Dataset和DataLoader
  3. 通过模型前向传播得到token embeddings
  4. 使用mean_pooling得到句子嵌入
  5. 进行归一化

一、核心参数定义

pretrained_repo = 'sentence-transformers/all-roberta-large-v1'
  • 功能:指定预训练模型的仓库地址(来自 Hugging Face Hub)
  • 细节解析
    • sentence-transformers:专门优化句子嵌入的库,基于 Hugging Face Transformers 封装,适合语义相似性、文本匹配任务
    • all-roberta-large-v1:模型基座是 RoBERTa(优化版 BERT),"large" 表示模型规模(隐藏层维度 1024),专门为句子级嵌入训练,知识图谱中用于将实体描述、关系描述等文本转换为语义向量(后续用于实体匹配、关系抽取等)
  • 作用:确定使用的预训练权重,避免从零训练,直接复用成熟的语义表示能力
batch_size = 256  # Adjust the batch size as needed
  • 功能:定义批量推理 / 训练的样本数量
  • 细节解析
    • 批次大小平衡「速度」和「内存」:大批次(如 256)推理更快,但占用 GPU/CPU 内存更多;小批次(如 32)内存占用少,但速度慢
    • 注释提示 "根据需求调整":需结合硬件配置(如 GPU 显存大小)修改,显存不足时减小数值
  • 作用:批量处理文本时控制单次输入模型的样本数,提升效率并避免内存溢出

二、数据集封装类(PyTorch 标准 Dataset)

class Dataset(torch.utils.data.Dataset):
  • 功能:继承 PyTorch 的Dataset基类,封装模型输入数据(input_idsattention_mask
  • 核心意义:PyTorch 的DataLoader需配合Dataset使用,实现自动批量加载、并行处理,尤其适合大规模文本(知识图谱中可能有上万条实体文本)
def __init__(self, input_ids=None, attention_mask=None):
    super().__init__()
    self.data = {
        "input_ids": input_ids,
        "att_mask": attention_mask,
    }
  • 功能:初始化数据集,接收 tokenizer 编码后的核心输入
  • 参数解析
    • input_ids:文本经 tokenizer 分词后,映射到词汇表的整数 ID(形状:[样本数,序列长度])
    • attention_mask:注意力掩码(0 表示无效 token,1 表示有效 token),用于告诉模型哪些字符需要关注(如 padding 补全的字符无需关注)
  • 逻辑:用字典self.data统一存储两个核心输入,方便后续按索引提取
def __len__(self):
    return self.data["input_ids"].size(0)
  • 功能:返回数据集的总样本数
  • 细节input_ids.size(0)取张量的第一个维度(样本数维度),因为input_ids的形状是「样本数 × 序列长度」
  • 作用DataLoader通过该方法知道数据集总长度,从而计算批次数量
def __getitem__(self, index):
    if isinstance(index, torch.Tensor):
        index = index.item()
    batch_data = dict()
    for key in self.data.keys():
        if self.data[key] is not None:
            batch_data[key] = self.data[key][index]
    return batch_data
  • 功能:根据索引index获取单个样本的输入数据
  • 关键逻辑
    1. 处理索引类型:若index是 PyTorch 张量(如某些场景下的张量索引),转换为 Python 标量避免报错
    2. 遍历self.data的键(input_idsatt_mask),提取对应索引的元素,封装成字典返回
  • 作用DataLoader会循环调用该方法,批量获取样本并组装成批次数据(如一次取 256 个样本)

三、SBERT 模型封装类(PyTorch 模型基类)

class Sentence_Transformer(nn.Module):
  • 功能:继承 PyTorch 的nn.Module(模型基类),封装 SBERT 的核心逻辑:预训练模型加载、句子嵌入池化、前向传播
  • 核心目标:将 BERT 输出的「token 级嵌入」转换为「句子级嵌入」(知识图谱需要的是整个实体 / 关系文本的语义向量,而非单个字符的向量)
def __init__(self, pretrained_repo):
    super(Sentence_Transformer, self).__init__()
    print(f"inherit model weights from {pretrained_repo}")
    self.bert_model = AutoModel.from_pretrained(pretrained_repo)
  • 功能:初始化模型,加载预训练权重
  • 关键解析
    • super(Sentence_Transformer, self).__init__():调用父类nn.Module的构造函数,初始化模型参数管理
    • AutoModel.from_pretrained(pretrained_repo):Hugging Face 的自动模型加载器,根据pretrained_repo自动识别模型类型(RoBERTa),下载并加载预训练权重
    • self.bert_model:存储预训练模型实例,用于提取 token 级嵌入
  • 作用:复用预训练模型的语义理解能力,无需从零训练(节省算力和数据)
def mean_pooling(self, model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    data_type = token_embeddings.dtype
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(data_type)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
  • 功能:实现「均值池化」,将 token 级嵌入聚合为句子级嵌入(SBERT 的核心操作)
  • 逐句解析
    1. token_embeddings = model_output[0]:BERT 类模型输出是元组,第一个元素是所有 token 的嵌入(形状:[batch_size, 序列长度,隐藏层维度],如 [256, 128, 1024])
    2. data_type = token_embeddings.dtype:获取嵌入的数据类型(如 float32),确保后续计算类型一致
    3. input_mask_expanded = ...:处理注意力掩码:
      • attention_mask.unsqueeze(-1):给掩码添加最后一个维度(从 [256, 128]→[256, 128, 1])
      • expand(token_embeddings.size()):扩展到与 token_embeddings 同形状([256, 128, 1024]),方便逐元素相乘
      • to(data_type):转换为与嵌入一致的数据类型
    4. return ...:计算均值:
      • token_embeddings * input_mask_expanded:无效 token(掩码为 0)的嵌入置为 0,只保留有效 token 的嵌入
      • sum(1):在「序列长度」维度求和([256, 128, 1024]→[256, 1024])
      • input_mask_expanded.sum(1):计算每个样本的有效 token 数量([256, 1024])
      • torch.clamp(..., min=1e-9):避免有效 token 数为 0 时除数为 0(防止 NaN 错误)
  • 作用:将分散的 token 嵌入聚合为整体语义向量,比仅用 [CLS] token 嵌入更稳定(知识图谱中实体描述的语义更完整)
def forward(self, input_ids, att_mask):
    bert_out = self.bert_model(input_ids=input_ids, attention_mask=att_mask)
    sentence_embeddings = self.mean_pooling(bert_out, att_mask)
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings
  • 功能:模型前向传播(推理 / 训练时的核心调用方法)
  • 逐句解析
    1. bert_out = self.bert_model(...):将批次输入(input_ids、注意力掩码)传入预训练模型,得到 token 级嵌入输出
    2. sentence_embeddings = self.mean_pooling(...):调用均值池化,得到句子级嵌入
    3. F.normalize(..., p=2, dim=1):L2 归一化(向量每个元素除以向量的 L2 范数):
      • p=2:L2 范数(欧氏距离)
      • dim=1:对每个样本的嵌入向量归一化
    4. return sentence_embeddings:返回归一化后的句子嵌入(形状:[batch_size, 1024])
  • 关键作用:归一化后,向量的余弦相似度等价于欧氏距离,后续计算实体 / 关系相似度时更高效(知识图谱中实体匹配、关系聚类常用)

四、模型加载函数

def load_sbert():
    model = Sentence_Transformer(pretrained_repo)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_repo)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model, tokenizer, device
  • 功能:统一加载模型、分词器和计算设备,简化调用流程
  • 逐句解析
    1. model = Sentence_Transformer(pretrained_repo):创建 SBERT 模型实例,加载预训练权重
    2. tokenizer = AutoTokenizer.from_pretrained(pretrained_repo):加载模型对应的分词器:
      • 分词器作用:将原始文本(如 "实体 A 是人工智能公司")转换为模型能识别的input_idsattention_mask
      • 与模型配套:确保分词规则、词汇表和预训练模型一致(否则嵌入效果会变差)
    3. device = torch.device(...):自动选择计算设备:
      • 有 GPU(CUDA 可用)则用 GPU(推理速度快,适合批量处理)
      • 无 GPU 则用 CPU(速度慢,适合小批量文本)
    4. model.to(device):将模型参数移动到指定设备(GPU/CPU),后续输入数据需与模型在同一设备
    5. model.eval():将模型设置为「评估模式」:
      • 禁用训练时的随机操作(如 Dropout)
      • 确保每次推理结果一致(知识图谱构建需要稳定的嵌入)
    6. return model, tokenizer, device:返回加载好的核心组件,供后续文本转嵌入使用

五、文本转嵌入函数(核心业务逻辑)

def sbert_text2embedding(model, tokenizer, device, text):
    try:
        # 1. 文本编码:转换为模型输入格式
        encoding = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        # 2. 封装数据集
        dataset = Dataset(input_ids=encoding.input_ids, attention_mask=encoding.attention_mask)
        # 3. 构建数据加载器(批量处理)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        # 4. 存储所有批次的嵌入
        all_embeddings = []
        # 5. 批量推理(禁用梯度计算,节省内存)
        with torch.no_grad():
            for batch in dataloader:
                # 5.1 数据移到对应设备
                batch = {key: value.to(device) for key, value in batch.items()}
                # 5.2 前向传播计算嵌入
                embeddings = model(input_ids=batch["input_ids"], att_mask=batch["att_mask"])
                # 5.3 保存当前批次嵌入
                all_embeddings.append(embeddings)
        # 6. 拼接所有批次的嵌入,移到CPU(方便后续处理)
        all_embeddings = torch.cat(all_embeddings, dim=0).cpu()
    except:
        # 异常处理:返回空嵌入(避免程序崩溃)
        return torch.zeros((0, 1024))
    return all_embeddings
  • 功能:接收原始文本(单个字符串或字符串列表),输出对应的 SBERT 嵌入(1024 维向量)
  • 逐块解析
1. 文本编码
encoding = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
  • text:输入文本(支持单个字符串,如 "苹果公司";或列表,如 ["苹果公司", "华为技术有限公司"])
  • 参数说明:
    • padding=True:批次内文本长度不一致时,用 padding token 补全到最长文本长度(确保批次张量形状统一)
    • truncation=True:文本长度超过模型最大输入长度(RoBERTa-large 默认 512)时,截断到最大长度
    • return_tensors='pt':返回 PyTorch 张量格式(与模型兼容)
  • 输出encoding:字典,包含input_ids([样本数,序列长度])和attention_mask([样本数,序列长度])
2. 封装数据集
dataset = Dataset(input_ids=encoding.input_ids, attention_mask=encoding.attention_mask)
  • 将编码后的张量封装到自定义Dataset类,适配后续DataLoader的批量加载
3. 构建 DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
  • DataLoader:PyTorch 的批量数据加载工具,自动按batch_size拆分数据集
  • shuffle=False:推理时不打乱顺序(确保输入文本与输出嵌入一一对应,知识图谱中实体文本和嵌入需保持映射关系)
4. 批量推理(核心步骤)
with torch.no_grad():
    for batch in dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        embeddings = model(input_ids=batch["input_ids"], att_mask=batch["att_mask"])
        all_embeddings.append(embeddings)
  • with torch.no_grad():禁用梯度计算(推理不需要更新模型权重,节省内存并加快速度)
  • 循环遍历每个批次:
    • 数据移到设备:确保input_idsatt_mask与模型在同一设备(GPU/CPU)
    • 前向传播:调用模型计算该批次的嵌入
    • 保存嵌入:将批次嵌入添加到列表(避免一次性存储所有数据导致内存溢出)
5. 结果拼接与异常处理
all_embeddings = torch.cat(all_embeddings, dim=0).cpu()
  • torch.cat(all_embeddings, dim=0):将所有批次的嵌入在「样本维度」(dim=0)拼接(如 2 个批次各 256 样本→512 样本的嵌入张量)
  • .cpu():将张量移到 CPU(方便后续存储到数据库、计算相似度等操作)
except:
    return torch.zeros((0, 1024))
  • 异常捕获:处理文本格式错误、模型加载失败、内存溢出等问题
  • 返回torch.zeros((0, 1024)):空嵌入张量(样本数 0,维度 1024),确保返回格式统一,避免下游代码报错

六、模型初始化

model, tokenizer, device = load_sbert()
  • 功能:执行模型加载函数,初始化核心组件
  • 后续使用场景
    • 知识图谱实体嵌入:emb = sbert_text2embedding(model, tokenizer, device, "实体A的描述文本")
    • 知识图谱关系嵌入:rel_emb = sbert_text2embedding(model, tokenizer, device, "关系R:属于")
    • 实体匹配:计算两个实体嵌入的余弦相似度,判断是否为同一实体
    • 关系抽取:用实体对嵌入 + 关系描述嵌入训练分类器,抽取实体间关系

整体逻辑总结

这段代码的核心是将文本转换为标准化的语义嵌入向量,专为知识图谱构建设计:

  1. 依赖sentence-transformers的预训练模型,快速获取高质量句子嵌入
  2. 用 PyTorch 的DatasetDataLoader实现批量处理,适配大规模知识图谱(上万条实体 / 关系)
  3. 均值池化 + L2 归一化确保嵌入的语义一致性和相似性计算效率
  4. 异常处理和设备自动选择提升代码健壮性
Logo

为武汉地区的开发者提供学习、交流和合作的平台。社区聚集了众多技术爱好者和专业人士,涵盖了多个领域,包括人工智能、大数据、云计算、区块链等。社区定期举办技术分享、培训和活动,为开发者提供更多的学习和交流机会。

更多推荐