AI人工智能之重排序原理与代码详解
本文介绍了基于Transformer架构的企业级需求管理系统中的重排序技术实现。系统采用Qwen3-Reranker-0.6B-ONNX模型,通过批处理、线程池优化和缓存机制提升性能,支持最大512序列长度和8批量大小。核心功能包括查询-文档语义匹配、动态负载均衡、异常恢复机制等,实现了无GPU环境下高效运行,为信息检索提供精准的重排序能力。
在我们的需求管理产品的AI人工智能开发中,一直存在检索和重排序准确度的困扰,这一问题涉及到分词、切片、嵌入模型、重排序模型和实现等多个问题,以下是我们在重排序实现上的一些理解,希望能抛砖引玉! 我们的大模型应用连接:http://aipoc.chtech.cn:8880/#/login欢迎试用。
重排序原理
1. 什么是重排序
重排序(Reranking)是信息检索和推荐系统中的关键技术,用于对初步检索到的文档进行精细化排序。它通常作为检索流程的最后一步,通过更复杂的模型对候选文档进行重新评分和排序。
2. 基于Transformer的重排序
Qwen3-Reranker-0.6B-ONNX使用Transformer架构,其工作原理如下:
-
输入格式:将查询(query)和文档(document)拼接成特定格式,如"query [SEP] document"
-
编码处理:通过Tokenizer将文本转换为模型可处理的token ID序列
-
特征提取:Transformer模型提取query-document对的深层语义特征
-
相关性评分:模型输出query与document之间的相关性分数
-
排序优化:根据分数对文档进行重新排序,提高检索质量
3. 重排序的优势
-
精度更高:比传统BM25等算法能更好地理解语义相关性
-
上下文感知:能够理解query和document的深层语义关系
-
个性化排序:可根据具体任务调整排序策略
代码详细解析
1. 类结构和配置
java
@Service
public class RerankingService {
// 关键配置参数
private static final int MAX_SEQ_LENGTH = 512; // 最大序列长度
private static final int MAX_BATCH_SIZE = 8; // 最大批处理大小
private static final int MAX_QUEUE_SIZE = 100; // 最大队列大小
private static final int TIMEOUT_MS = 120000; // 超时时间
private static final int NUM_THREADS = Math.max(2, Runtime.getRuntime().availableProcessors() - 1);
// 模型路径和组件
private String MODEL_PATH;
private String TOKENIZER_PATH;
private OrtEnvironment env;
private HuggingFaceTokenizer tokenizer;
private OrtSession session;
// 线程池和缓存
private ExecutorService inferenceExecutor;
private ThreadPoolExecutor batchExecutorForSpringDoc;
private ThreadPoolExecutor batchExecutorForCustomDoc;
private final Map<String, Float> scoreCache = new LinkedHashMap<>(1000, 0.75f, true);
private final BidiMap<String, List<Pair<Double, Document>>> resultCache = new DualHashBidiMap<>();
}
2. 初始化过程
java
@PostConstruct
public void init() throws Exception {
// 模型路径设置
this.MODEL_PATH = Paths.get(rerankerModelPath, "model.onnx").toString();
this.TOKENIZER_PATH = Paths.get(rerankerModelPath, "tokenizer.json").toString();
// 环境初始化
this.env = onnxEnvManager.getEnvironment();
initTokenizer(); // 初始化tokenizer
initSession(); // 初始化ONNX会话
// 线程池初始化
this.inferenceExecutor = Executors.newFixedThreadPool(NUM_THREADS, new CustomThreadFactory("inference"));
this.batchExecutorForSpringDoc = new ThreadPoolExecutor(...);
this.batchExecutorForCustomDoc = new ThreadPoolExecutor(...);
}
3. 核心重排序方法
3.1 主要重排序方法
java
public List<Document> rerank(String query, List<Document> documents) {
// 1. 输入验证和预处理
if (documents == null || documents.isEmpty()) return documents;
// 2. 模型重启检查
if (closed) reinitialize();
// 3. 负载检查
if (batchExecutorForSpringDoc.getQueue().size() > MAX_QUEUE_SIZE * 0.8) return documents;
// 4. 分批处理
List<List<Document>> batches = partitionList(documents, maxBatchSize);
List<Future<List<Document>>> futures = new ArrayList<>();
// 5. 异步提交任务
for (List<Document> batch : batches) {
futures.add(docCompletionService.submit(() -> processDocumentBatch(query, batch)));
}
// 6. 结果收集和排序
List<Document> results = new ArrayList<>();
for (Future<List<Document>> future : futures) {
results.addAll(future.get(TIMEOUT_MS, TimeUnit.MILLISECONDS));
}
// 7. 按分数降序排序
results.sort((d1, d2) -> Float.compare(getScoreFromDocument(d2), getScoreFromDocument(d1)));
return results;
}
3.2 批处理过程
java
private List<Document> processDocumentBatch(String query, List<Document> batch) {
// 1. 准备输入
List<String> inputs = new ArrayList<>();
for (Document doc : batch) {
inputs.add(formatInput(query, doc.getText())); // 格式化为"query [SEP] document"
}
// 2. 批量预测
List<Float> scores = batchPredict(inputs);
// 3. 更新文档元数据
List<Document> results = new ArrayList<>();
for (int i = 0; i < batch.size(); i++) {
Document original = batch.get(i);
float score = scores.get(i);
Map<String, Object> metadata = new HashMap<>(original.getMetadata());
metadata.put("relevance_score", score);
results.add(original.mutate()
.score((double) score)
.metadata(metadata)
.build());
}
return results;
}
4. 模型推理核心
java
private List<Float> batchPredict(List<String> texts) throws OrtException {
// 1. 准备输入张量
Map<String, OnnxTensor> inputs = prepareBatchInputs(texts);
try (OrtSession.Result results = session.run(inputs)) {
// 2. 获取输出张量
OnnxValue outputValue = results.get(0);
OnnxTensor outputTensor = (OnnxTensor) outputValue;
// 3. 根据输出形状处理不同情况
long[] shape = outputTensor.getInfo().getShape();
if (shape.length == 3) {
// 语言模型输出处理
return processLanguageModelOutput(outputTensor, (int) shape[0]);
} else if (shape.length == 2 && shape[1] == 1) {
// 2D输出处理
return process2DOutput(outputTensor, (int) shape[0]);
} else if (shape.length == 1) {
// 1D输出处理
return process1DOutput(outputTensor, (int) shape[0]);
}
}
}
5. 输入预处理
java
private Map<String, OnnxTensor> prepareBatchInputs(List<String> texts) throws OrtException {
int batchSize = texts.size();
int maxSeqLength = 0;
// 1. Tokenize所有文本
List<Encoding> encodings = new ArrayList<>();
for (String text : texts) {
Encoding encoding = tokenizer.encode(text);
encodings.add(encoding);
maxSeqLength = Math.max(maxSeqLength, Math.min(encoding.getIds().length, MAX_SEQ_LENGTH));
}
// 2. 构建输入张量
long[][] inputIds = new long[batchSize][maxSeqLength];
long[][] attentionMask = new long[batchSize][maxSeqLength];
long[][] positionIds = new long[batchSize][maxSeqLength];
// 3. 填充和对齐序列
for (int i = 0; i < batchSize; i++) {
Encoding encoding = encodings.get(i);
long[] ids = encoding.getIds();
int length = Math.min(ids.length, maxSeqLength);
System.arraycopy(ids, 0, inputIds[i], 0, length);
// 类似处理attention mask和position ids...
// 填充不足部分
if (length < maxSeqLength) {
Arrays.fill(inputIds[i], length, maxSeqLength, 0L);
}
}
return Map.of(
"input_ids", OnnxTensor.createTensor(env, inputIds),
"attention_mask", OnnxTensor.createTensor(env, attentionMask),
"position_ids", OnnxTensor.createTensor(env, positionIds)
);
}
6. 性能优化特性
-
批处理优化:支持批量处理,提高CPU利用率
-
线程池管理:使用多个专用线程池避免阻塞
-
缓存机制:LRU缓存避免重复计算
-
动态负载均衡:队列满时直接返回原结果
-
超时处理:防止长时间阻塞
7. 异常处理和资源管理
java
@PreDestroy
public void close() {
// 优雅关闭线程池
shutdownExecutor(batchExecutorForSpringDoc, "Spring文档批处理线程池");
shutdownExecutor(batchExecutorForCustomDoc, "自定义文档批处理线程池");
// 释放模型资源
if (session != null) safeClose(session);
if (tokenizer != null) {
try { tokenizer.close(); } catch (Exception e) { /* 静默处理 */ }
}
closed = true;
}
总结
这个重排序服务实现了:
-
高效的CPU推理:通过ONNX Runtime优化CPU推理性能
-
灵活的批处理:支持动态批处理大小和序列长度
-
健壮的错误处理:完善的异常处理和资源管理
-
可扩展的架构:支持多种文档类型和排序策略
-
性能优化:通过缓存、线程池等技术提高吞吐量
该服务能够在无GPU环境下有效运行Qwen3重排序模型,为检索系统提供高质量的重排序能力。
更多推荐
所有评论(0)