在我们的需求管理产品的AI人工智能开发中,一直存在检索和重排序准确度的困扰,这一问题涉及到分词、切片、嵌入模型、重排序模型和实现等多个问题,以下是我们在重排序实现上的一些理解,希望能抛砖引玉! 我们的大模型应用连接:http://aipoc.chtech.cn:8880/#/login欢迎试用。

重排序原理

1. 什么是重排序

重排序(Reranking)是信息检索和推荐系统中的关键技术,用于对初步检索到的文档进行精细化排序。它通常作为检索流程的最后一步,通过更复杂的模型对候选文档进行重新评分和排序。

2. 基于Transformer的重排序

Qwen3-Reranker-0.6B-ONNX使用Transformer架构,其工作原理如下:

  1. 输入格式:将查询(query)和文档(document)拼接成特定格式,如"query [SEP] document"

  2. 编码处理:通过Tokenizer将文本转换为模型可处理的token ID序列

  3. 特征提取:Transformer模型提取query-document对的深层语义特征

  4. 相关性评分:模型输出query与document之间的相关性分数

  5. 排序优化:根据分数对文档进行重新排序,提高检索质量

3. 重排序的优势

  • 精度更高:比传统BM25等算法能更好地理解语义相关性

  • 上下文感知:能够理解query和document的深层语义关系

  • 个性化排序:可根据具体任务调整排序策略

代码详细解析

1. 类结构和配置

java

@Service
public class RerankingService {
    // 关键配置参数
    private static final int MAX_SEQ_LENGTH = 512;      // 最大序列长度
    private static final int MAX_BATCH_SIZE = 8;        // 最大批处理大小
    private static final int MAX_QUEUE_SIZE = 100;      // 最大队列大小
    private static final int TIMEOUT_MS = 120000;       // 超时时间
    private static final int NUM_THREADS = Math.max(2, Runtime.getRuntime().availableProcessors() - 1);
    
    // 模型路径和组件
    private String MODEL_PATH;
    private String TOKENIZER_PATH;
    private OrtEnvironment env;
    private HuggingFaceTokenizer tokenizer;
    private OrtSession session;
    
    // 线程池和缓存
    private ExecutorService inferenceExecutor;
    private ThreadPoolExecutor batchExecutorForSpringDoc;
    private ThreadPoolExecutor batchExecutorForCustomDoc;
    private final Map<String, Float> scoreCache = new LinkedHashMap<>(1000, 0.75f, true);
    private final BidiMap<String, List<Pair<Double, Document>>> resultCache = new DualHashBidiMap<>();
}

2. 初始化过程

java

@PostConstruct
public void init() throws Exception {
    // 模型路径设置
    this.MODEL_PATH = Paths.get(rerankerModelPath, "model.onnx").toString();
    this.TOKENIZER_PATH = Paths.get(rerankerModelPath, "tokenizer.json").toString();
    
    // 环境初始化
    this.env = onnxEnvManager.getEnvironment();
    initTokenizer();    // 初始化tokenizer
    initSession();      // 初始化ONNX会话
    
    // 线程池初始化
    this.inferenceExecutor = Executors.newFixedThreadPool(NUM_THREADS, new CustomThreadFactory("inference"));
    this.batchExecutorForSpringDoc = new ThreadPoolExecutor(...);
    this.batchExecutorForCustomDoc = new ThreadPoolExecutor(...);
}

3. 核心重排序方法

3.1 主要重排序方法

java

public List<Document> rerank(String query, List<Document> documents) {
    // 1. 输入验证和预处理
    if (documents == null || documents.isEmpty()) return documents;
    
    // 2. 模型重启检查
    if (closed) reinitialize();
    
    // 3. 负载检查
    if (batchExecutorForSpringDoc.getQueue().size() > MAX_QUEUE_SIZE * 0.8) return documents;
    
    // 4. 分批处理
    List<List<Document>> batches = partitionList(documents, maxBatchSize);
    List<Future<List<Document>>> futures = new ArrayList<>();
    
    // 5. 异步提交任务
    for (List<Document> batch : batches) {
        futures.add(docCompletionService.submit(() -> processDocumentBatch(query, batch)));
    }
    
    // 6. 结果收集和排序
    List<Document> results = new ArrayList<>();
    for (Future<List<Document>> future : futures) {
        results.addAll(future.get(TIMEOUT_MS, TimeUnit.MILLISECONDS));
    }
    
    // 7. 按分数降序排序
    results.sort((d1, d2) -> Float.compare(getScoreFromDocument(d2), getScoreFromDocument(d1)));
    
    return results;
}

3.2 批处理过程

java

private List<Document> processDocumentBatch(String query, List<Document> batch) {
    // 1. 准备输入
    List<String> inputs = new ArrayList<>();
    for (Document doc : batch) {
        inputs.add(formatInput(query, doc.getText())); // 格式化为"query [SEP] document"
    }
    
    // 2. 批量预测
    List<Float> scores = batchPredict(inputs);
    
    // 3. 更新文档元数据
    List<Document> results = new ArrayList<>();
    for (int i = 0; i < batch.size(); i++) {
        Document original = batch.get(i);
        float score = scores.get(i);
        
        Map<String, Object> metadata = new HashMap<>(original.getMetadata());
        metadata.put("relevance_score", score);
        
        results.add(original.mutate()
                .score((double) score)
                .metadata(metadata)
                .build());
    }
    
    return results;
}

4. 模型推理核心

java

private List<Float> batchPredict(List<String> texts) throws OrtException {
    // 1. 准备输入张量
    Map<String, OnnxTensor> inputs = prepareBatchInputs(texts);
    
    try (OrtSession.Result results = session.run(inputs)) {
        // 2. 获取输出张量
        OnnxValue outputValue = results.get(0);
        OnnxTensor outputTensor = (OnnxTensor) outputValue;
        
        // 3. 根据输出形状处理不同情况
        long[] shape = outputTensor.getInfo().getShape();
        
        if (shape.length == 3) {
            // 语言模型输出处理
            return processLanguageModelOutput(outputTensor, (int) shape[0]);
        } else if (shape.length == 2 && shape[1] == 1) {
            // 2D输出处理
            return process2DOutput(outputTensor, (int) shape[0]);
        } else if (shape.length == 1) {
            // 1D输出处理
            return process1DOutput(outputTensor, (int) shape[0]);
        }
    }
}

5. 输入预处理

java

private Map<String, OnnxTensor> prepareBatchInputs(List<String> texts) throws OrtException {
    int batchSize = texts.size();
    int maxSeqLength = 0;
    
    // 1. Tokenize所有文本
    List<Encoding> encodings = new ArrayList<>();
    for (String text : texts) {
        Encoding encoding = tokenizer.encode(text);
        encodings.add(encoding);
        maxSeqLength = Math.max(maxSeqLength, Math.min(encoding.getIds().length, MAX_SEQ_LENGTH));
    }
    
    // 2. 构建输入张量
    long[][] inputIds = new long[batchSize][maxSeqLength];
    long[][] attentionMask = new long[batchSize][maxSeqLength];
    long[][] positionIds = new long[batchSize][maxSeqLength];
    
    // 3. 填充和对齐序列
    for (int i = 0; i < batchSize; i++) {
        Encoding encoding = encodings.get(i);
        long[] ids = encoding.getIds();
        int length = Math.min(ids.length, maxSeqLength);
        
        System.arraycopy(ids, 0, inputIds[i], 0, length);
        // 类似处理attention mask和position ids...
        
        // 填充不足部分
        if (length < maxSeqLength) {
            Arrays.fill(inputIds[i], length, maxSeqLength, 0L);
        }
    }
    
    return Map.of(
            "input_ids", OnnxTensor.createTensor(env, inputIds),
            "attention_mask", OnnxTensor.createTensor(env, attentionMask),
            "position_ids", OnnxTensor.createTensor(env, positionIds)
    );
}

6. 性能优化特性

  1. 批处理优化:支持批量处理,提高CPU利用率

  2. 线程池管理:使用多个专用线程池避免阻塞

  3. 缓存机制:LRU缓存避免重复计算

  4. 动态负载均衡:队列满时直接返回原结果

  5. 超时处理:防止长时间阻塞

7. 异常处理和资源管理

java

@PreDestroy
public void close() {
    // 优雅关闭线程池
    shutdownExecutor(batchExecutorForSpringDoc, "Spring文档批处理线程池");
    shutdownExecutor(batchExecutorForCustomDoc, "自定义文档批处理线程池");
    
    // 释放模型资源
    if (session != null) safeClose(session);
    if (tokenizer != null) {
        try { tokenizer.close(); } catch (Exception e) { /* 静默处理 */ }
    }
    
    closed = true;
}

总结

这个重排序服务实现了:

  1. 高效的CPU推理:通过ONNX Runtime优化CPU推理性能

  2. 灵活的批处理:支持动态批处理大小和序列长度

  3. 健壮的错误处理:完善的异常处理和资源管理

  4. 可扩展的架构:支持多种文档类型和排序策略

  5. 性能优化:通过缓存、线程池等技术提高吞吐量

该服务能够在无GPU环境下有效运行Qwen3重排序模型,为检索系统提供高质量的重排序能力。

Logo

更多推荐