K_RagRec代码阅读:lm_modelling.py
功能:指定预训练模型的仓库地址(来自 Hugging Face Hub)细节解析:专门优化句子嵌入的库,基于 Hugging Face Transformers 封装,适合语义相似性、文本匹配任务:模型基座是 RoBERTa(优化版 BERT),"large" 表示模型规模(隐藏层维度 1024),专门为句子级嵌入训练,知识图谱中用于将实体描述、关系描述等文本转换为语义向量(后续用于实体匹配、关系
·
这段代码的核心是将文本转换为标准化的语义嵌入向量,专为知识图谱构建设计
1. Sentence Transformer的工作原理
Sentence_Transformer完全按照Sentence Transformer的工作原理设计:
- 使用
AutoModel.from_pretrained加载BERT模型 mean_pooling方法将token embeddings平均池化为句子嵌入F.normalize进行L2归一化,使向量长度为1
这与传统的Word2Vec不同,Sentence Transformer能捕捉更丰富的语义信息。
2. 文本处理流程
sbert_text2embedding函数的处理流程:
- 使用tokenizer将文本转换为token ids和attention mask(注意力掩码)
- 将数据组织成PyTorch Dataset和DataLoader
- 通过模型前向传播得到token embeddings
- 使用
mean_pooling得到句子嵌入 - 进行归一化
一、核心参数定义
pretrained_repo = 'sentence-transformers/all-roberta-large-v1'
- 功能:指定预训练模型的仓库地址(来自 Hugging Face Hub)
- 细节解析:
sentence-transformers:专门优化句子嵌入的库,基于 Hugging Face Transformers 封装,适合语义相似性、文本匹配任务all-roberta-large-v1:模型基座是 RoBERTa(优化版 BERT),"large" 表示模型规模(隐藏层维度 1024),专门为句子级嵌入训练,知识图谱中用于将实体描述、关系描述等文本转换为语义向量(后续用于实体匹配、关系抽取等)
- 作用:确定使用的预训练权重,避免从零训练,直接复用成熟的语义表示能力
batch_size = 256 # Adjust the batch size as needed
- 功能:定义批量推理 / 训练的样本数量
- 细节解析:
- 批次大小平衡「速度」和「内存」:大批次(如 256)推理更快,但占用 GPU/CPU 内存更多;小批次(如 32)内存占用少,但速度慢
- 注释提示 "根据需求调整":需结合硬件配置(如 GPU 显存大小)修改,显存不足时减小数值
- 作用:批量处理文本时控制单次输入模型的样本数,提升效率并避免内存溢出
二、数据集封装类(PyTorch 标准 Dataset)
class Dataset(torch.utils.data.Dataset):
- 功能:继承 PyTorch 的
Dataset基类,封装模型输入数据(input_ids和attention_mask) - 核心意义:PyTorch 的
DataLoader需配合Dataset使用,实现自动批量加载、并行处理,尤其适合大规模文本(知识图谱中可能有上万条实体文本)
def __init__(self, input_ids=None, attention_mask=None):
super().__init__()
self.data = {
"input_ids": input_ids,
"att_mask": attention_mask,
}
- 功能:初始化数据集,接收 tokenizer 编码后的核心输入
- 参数解析:
input_ids:文本经 tokenizer 分词后,映射到词汇表的整数 ID(形状:[样本数,序列长度])attention_mask:注意力掩码(0 表示无效 token,1 表示有效 token),用于告诉模型哪些字符需要关注(如 padding 补全的字符无需关注)
- 逻辑:用字典
self.data统一存储两个核心输入,方便后续按索引提取
def __len__(self):
return self.data["input_ids"].size(0)
- 功能:返回数据集的总样本数
- 细节:
input_ids.size(0)取张量的第一个维度(样本数维度),因为input_ids的形状是「样本数 × 序列长度」 - 作用:
DataLoader通过该方法知道数据集总长度,从而计算批次数量
def __getitem__(self, index):
if isinstance(index, torch.Tensor):
index = index.item()
batch_data = dict()
for key in self.data.keys():
if self.data[key] is not None:
batch_data[key] = self.data[key][index]
return batch_data
- 功能:根据索引
index获取单个样本的输入数据 - 关键逻辑:
- 处理索引类型:若
index是 PyTorch 张量(如某些场景下的张量索引),转换为 Python 标量避免报错 - 遍历
self.data的键(input_ids和att_mask),提取对应索引的元素,封装成字典返回
- 处理索引类型:若
- 作用:
DataLoader会循环调用该方法,批量获取样本并组装成批次数据(如一次取 256 个样本)
三、SBERT 模型封装类(PyTorch 模型基类)
class Sentence_Transformer(nn.Module):
- 功能:继承 PyTorch 的
nn.Module(模型基类),封装 SBERT 的核心逻辑:预训练模型加载、句子嵌入池化、前向传播 - 核心目标:将 BERT 输出的「token 级嵌入」转换为「句子级嵌入」(知识图谱需要的是整个实体 / 关系文本的语义向量,而非单个字符的向量)
def __init__(self, pretrained_repo):
super(Sentence_Transformer, self).__init__()
print(f"inherit model weights from {pretrained_repo}")
self.bert_model = AutoModel.from_pretrained(pretrained_repo)
- 功能:初始化模型,加载预训练权重
- 关键解析:
super(Sentence_Transformer, self).__init__():调用父类nn.Module的构造函数,初始化模型参数管理AutoModel.from_pretrained(pretrained_repo):Hugging Face 的自动模型加载器,根据pretrained_repo自动识别模型类型(RoBERTa),下载并加载预训练权重self.bert_model:存储预训练模型实例,用于提取 token 级嵌入
- 作用:复用预训练模型的语义理解能力,无需从零训练(节省算力和数据)
def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
data_type = token_embeddings.dtype
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(data_type)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
- 功能:实现「均值池化」,将 token 级嵌入聚合为句子级嵌入(SBERT 的核心操作)
- 逐句解析:
token_embeddings = model_output[0]:BERT 类模型输出是元组,第一个元素是所有 token 的嵌入(形状:[batch_size, 序列长度,隐藏层维度],如 [256, 128, 1024])data_type = token_embeddings.dtype:获取嵌入的数据类型(如 float32),确保后续计算类型一致input_mask_expanded = ...:处理注意力掩码:attention_mask.unsqueeze(-1):给掩码添加最后一个维度(从 [256, 128]→[256, 128, 1])expand(token_embeddings.size()):扩展到与 token_embeddings 同形状([256, 128, 1024]),方便逐元素相乘to(data_type):转换为与嵌入一致的数据类型
return ...:计算均值:token_embeddings * input_mask_expanded:无效 token(掩码为 0)的嵌入置为 0,只保留有效 token 的嵌入sum(1):在「序列长度」维度求和([256, 128, 1024]→[256, 1024])input_mask_expanded.sum(1):计算每个样本的有效 token 数量([256, 1024])torch.clamp(..., min=1e-9):避免有效 token 数为 0 时除数为 0(防止 NaN 错误)
- 作用:将分散的 token 嵌入聚合为整体语义向量,比仅用 [CLS] token 嵌入更稳定(知识图谱中实体描述的语义更完整)
def forward(self, input_ids, att_mask):
bert_out = self.bert_model(input_ids=input_ids, attention_mask=att_mask)
sentence_embeddings = self.mean_pooling(bert_out, att_mask)
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings
- 功能:模型前向传播(推理 / 训练时的核心调用方法)
- 逐句解析:
bert_out = self.bert_model(...):将批次输入(input_ids、注意力掩码)传入预训练模型,得到 token 级嵌入输出sentence_embeddings = self.mean_pooling(...):调用均值池化,得到句子级嵌入F.normalize(..., p=2, dim=1):L2 归一化(向量每个元素除以向量的 L2 范数):p=2:L2 范数(欧氏距离)dim=1:对每个样本的嵌入向量归一化
return sentence_embeddings:返回归一化后的句子嵌入(形状:[batch_size, 1024])
- 关键作用:归一化后,向量的余弦相似度等价于欧氏距离,后续计算实体 / 关系相似度时更高效(知识图谱中实体匹配、关系聚类常用)
四、模型加载函数
def load_sbert():
model = Sentence_Transformer(pretrained_repo)
tokenizer = AutoTokenizer.from_pretrained(pretrained_repo)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
return model, tokenizer, device
- 功能:统一加载模型、分词器和计算设备,简化调用流程
- 逐句解析:
model = Sentence_Transformer(pretrained_repo):创建 SBERT 模型实例,加载预训练权重tokenizer = AutoTokenizer.from_pretrained(pretrained_repo):加载模型对应的分词器:- 分词器作用:将原始文本(如 "实体 A 是人工智能公司")转换为模型能识别的
input_ids和attention_mask - 与模型配套:确保分词规则、词汇表和预训练模型一致(否则嵌入效果会变差)
- 分词器作用:将原始文本(如 "实体 A 是人工智能公司")转换为模型能识别的
device = torch.device(...):自动选择计算设备:- 有 GPU(CUDA 可用)则用 GPU(推理速度快,适合批量处理)
- 无 GPU 则用 CPU(速度慢,适合小批量文本)
model.to(device):将模型参数移动到指定设备(GPU/CPU),后续输入数据需与模型在同一设备model.eval():将模型设置为「评估模式」:- 禁用训练时的随机操作(如 Dropout)
- 确保每次推理结果一致(知识图谱构建需要稳定的嵌入)
return model, tokenizer, device:返回加载好的核心组件,供后续文本转嵌入使用
五、文本转嵌入函数(核心业务逻辑)
def sbert_text2embedding(model, tokenizer, device, text):
try:
# 1. 文本编码:转换为模型输入格式
encoding = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
# 2. 封装数据集
dataset = Dataset(input_ids=encoding.input_ids, attention_mask=encoding.attention_mask)
# 3. 构建数据加载器(批量处理)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# 4. 存储所有批次的嵌入
all_embeddings = []
# 5. 批量推理(禁用梯度计算,节省内存)
with torch.no_grad():
for batch in dataloader:
# 5.1 数据移到对应设备
batch = {key: value.to(device) for key, value in batch.items()}
# 5.2 前向传播计算嵌入
embeddings = model(input_ids=batch["input_ids"], att_mask=batch["att_mask"])
# 5.3 保存当前批次嵌入
all_embeddings.append(embeddings)
# 6. 拼接所有批次的嵌入,移到CPU(方便后续处理)
all_embeddings = torch.cat(all_embeddings, dim=0).cpu()
except:
# 异常处理:返回空嵌入(避免程序崩溃)
return torch.zeros((0, 1024))
return all_embeddings
- 功能:接收原始文本(单个字符串或字符串列表),输出对应的 SBERT 嵌入(1024 维向量)
- 逐块解析:
1. 文本编码
encoding = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
text:输入文本(支持单个字符串,如 "苹果公司";或列表,如 ["苹果公司", "华为技术有限公司"])- 参数说明:
padding=True:批次内文本长度不一致时,用 padding token 补全到最长文本长度(确保批次张量形状统一)truncation=True:文本长度超过模型最大输入长度(RoBERTa-large 默认 512)时,截断到最大长度return_tensors='pt':返回 PyTorch 张量格式(与模型兼容)
- 输出
encoding:字典,包含input_ids([样本数,序列长度])和attention_mask([样本数,序列长度])
2. 封装数据集
dataset = Dataset(input_ids=encoding.input_ids, attention_mask=encoding.attention_mask)
- 将编码后的张量封装到自定义
Dataset类,适配后续DataLoader的批量加载
3. 构建 DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
DataLoader:PyTorch 的批量数据加载工具,自动按batch_size拆分数据集shuffle=False:推理时不打乱顺序(确保输入文本与输出嵌入一一对应,知识图谱中实体文本和嵌入需保持映射关系)
4. 批量推理(核心步骤)
with torch.no_grad():
for batch in dataloader:
batch = {key: value.to(device) for key, value in batch.items()}
embeddings = model(input_ids=batch["input_ids"], att_mask=batch["att_mask"])
all_embeddings.append(embeddings)
with torch.no_grad():禁用梯度计算(推理不需要更新模型权重,节省内存并加快速度)- 循环遍历每个批次:
- 数据移到设备:确保
input_ids和att_mask与模型在同一设备(GPU/CPU) - 前向传播:调用模型计算该批次的嵌入
- 保存嵌入:将批次嵌入添加到列表(避免一次性存储所有数据导致内存溢出)
- 数据移到设备:确保
5. 结果拼接与异常处理
all_embeddings = torch.cat(all_embeddings, dim=0).cpu()
torch.cat(all_embeddings, dim=0):将所有批次的嵌入在「样本维度」(dim=0)拼接(如 2 个批次各 256 样本→512 样本的嵌入张量).cpu():将张量移到 CPU(方便后续存储到数据库、计算相似度等操作)
except:
return torch.zeros((0, 1024))
- 异常捕获:处理文本格式错误、模型加载失败、内存溢出等问题
- 返回
torch.zeros((0, 1024)):空嵌入张量(样本数 0,维度 1024),确保返回格式统一,避免下游代码报错
六、模型初始化
model, tokenizer, device = load_sbert()
- 功能:执行模型加载函数,初始化核心组件
- 后续使用场景:
- 知识图谱实体嵌入:
emb = sbert_text2embedding(model, tokenizer, device, "实体A的描述文本") - 知识图谱关系嵌入:
rel_emb = sbert_text2embedding(model, tokenizer, device, "关系R:属于") - 实体匹配:计算两个实体嵌入的余弦相似度,判断是否为同一实体
- 关系抽取:用实体对嵌入 + 关系描述嵌入训练分类器,抽取实体间关系
- 知识图谱实体嵌入:
整体逻辑总结
这段代码的核心是将文本转换为标准化的语义嵌入向量,专为知识图谱构建设计:
- 依赖
sentence-transformers的预训练模型,快速获取高质量句子嵌入 - 用 PyTorch 的
Dataset和DataLoader实现批量处理,适配大规模知识图谱(上万条实体 / 关系) - 均值池化 + L2 归一化确保嵌入的语义一致性和相似性计算效率
- 异常处理和设备自动选择提升代码健壮性
为武汉地区的开发者提供学习、交流和合作的平台。社区聚集了众多技术爱好者和专业人士,涵盖了多个领域,包括人工智能、大数据、云计算、区块链等。社区定期举办技术分享、培训和活动,为开发者提供更多的学习和交流机会。
更多推荐



所有评论(0)