LangChain检索增强:从冗余过滤到精准排序的检索质量提升机制

目录

  1. 核心定义与价值:检索增强解决什么核心问题?
  2. 底层实现逻辑:检索增强如何提升检索质量?
  3. 代码实践:从基础检索到增强检索如何落地?
  4. 设计考量:为什么LangChain要这样设计检索增强机制?
  5. 替代方案与优化空间

1. 核心定义与价值:检索增强解决什么核心问题?

1.1 核心定义

ContextualCompressionRetriever 是LangChain中的"检索结果精炼工具",它通过包装基础检索器并应用文档压缩策略,实现对检索结果的智能过滤和压缩。重排序机制则作为"相关性优化机制",通过更精细的语义匹配算法重新调整文档的相关性排序。

1.2 未增强基础检索的核心缺陷

传统的基础检索存在三大关键问题:

  1. 冗余信息干扰:检索到的文档包含大量与查询无关的内容,稀释了关键信息
  2. 相关度排序不准:基于简单向量相似度的排序无法捕捉复杂的语义关联
  3. 过长文档消耗Token:未经压缩的长文档会消耗大量LLM处理Token,影响效率和成本

1.3 检索增强在RAG中的核心位置

用户查询
基础检索器
原始检索结果
ContextualCompressionRetriever
文档压缩/过滤
压缩后文档
重排序机制
相关性调整
精炼结果
LLM生成
最终回答

1.4 检索增强的关键特性

  • 语义感知过滤:基于查询上下文智能判断文档片段的相关性
  • 冗余信息剔除:自动移除与查询无关的内容,保留核心信息
  • 相关度重排:使用更精细的语义匹配模型重新排序文档
  • 跨模态适配潜力:支持文本、图像等多模态内容的检索增强

2. 底层实现逻辑:检索增强如何提升检索质量?

2.1 ContextualCompressionRetriever工作原理

ContextualCompressionRetriever采用"基础检索器+压缩器"的组合架构:

# 核心工作流程
def _get_relevant_documents(self, query: str) -> List[Document]:
    # 1. 使用基础检索器获取原始文档
    docs = self.base_retriever.invoke(query)
    
    # 2. 应用压缩器进行文档压缩
    if docs:
        compressed_docs = self.base_compressor.compress_documents(docs, query)
        return list(compressed_docs)
    return []

2.2 不同压缩策略的适用场景

2.2.1 LLMChainExtractor - 智能内容提取
  • 工作机制:使用LLM分析文档内容,提取与查询相关的关键片段
  • 适用场景:需要深度语义理解的复杂查询
  • 优势:高精度的内容提取,能理解复杂的语义关联
  • 劣势:计算成本较高,处理速度相对较慢
2.2.2 EmbeddingsFilter - 向量相似度过滤
  • 工作机制:计算查询与文档的向量相似度,过滤低相关性文档
  • 适用场景:大规模文档集合的快速过滤
  • 优势:处理速度快,计算成本低
  • 劣势:可能错过语义相关但向量距离较远的内容

2.3 重排序核心机制

重排序机制基于"细粒度语义匹配"实现排序优化:

2.3.1 交叉编码器模型
# CrossEncoderReranker工作流程
def compress_documents(self, documents, query):
    # 1. 计算查询-文档对的相似度分数
    scores = self.model.score([(query, doc.page_content) for doc in documents])
    
    # 2. 根据分数重新排序
    docs_with_scores = list(zip(documents, scores))
    result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
    
    # 3. 返回Top-N结果
    return [doc for doc, _ in result[:self.top_n]]
2.3.2 与基础检索的差异对比
维度 基础检索(向量相似度) 重排序(交叉编码器)
计算方式 独立向量相似度计算 查询-文档对联合编码
语义理解 浅层语义匹配 深层语义交互
计算复杂度 O(n) O(n²)
准确性 中等
适用场景 大规模初筛 精细化排序

2.4 协同优化逻辑

检索增强采用"先过滤冗余→再提升排序"的递进增强策略:

  1. 第一阶段-冗余过滤:移除明显不相关的内容,减少噪声干扰
  2. 第二阶段-精准排序:在过滤后的高质量候选集上进行精细排序
  3. 协同效应:过滤减少了重排序的计算负担,重排序提升了最终结果质量

3. 代码实践:从基础检索到增强检索如何落地?

3.1 环境准备

# 安装依赖包
pip install langchain langchain-community langchain-openai
pip install chromadb sentence-transformers
pip install rank-bm25 numpy

3.2 基础实践1:ContextualCompressionRetriever核心用法

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor, EmbeddingsFilter
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# 1. 准备文档数据
loader = TextLoader("sample_documents.txt")
documents = loader.load()

# 2. 文档分割
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200
)
splits = text_splitter.split_documents(documents)

# 3. 创建向量存储和基础检索器
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(splits, embeddings)
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

# 4. 配置LLM压缩器
llm = ChatOpenAI(temperature=0)
compressor = LLMChainExtractor.from_llm(llm)

# 5. 构建压缩检索器
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=base_retriever
)

# 6. 对比检索效果
query = "什么是机器学习的监督学习?"

# 基础检索结果
basic_docs = base_retriever.get_relevant_documents(query)
print(f"基础检索返回 {len(basic_docs)} 个文档")
print(f"平均文档长度: {sum(len(doc.page_content) for doc in basic_docs) / len(basic_docs):.0f} 字符")

# 压缩检索结果
compressed_docs = compression_retriever.get_relevant_documents(query)
print(f"压缩检索返回 {len(compressed_docs)} 个文档")
print(f"平均文档长度: {sum(len(doc.page_content) for doc in compressed_docs) / len(compressed_docs):.0f} 字符")

# 7. 展示压缩效果
for i, doc in enumerate(compressed_docs[:2]):
    print(f"\n压缩文档 {i+1}:")
    print(doc.page_content[:200] + "...")

3.3 基础实践2:重排序机制实现与效果对比

from langchain.retrievers.document_compressors import CrossEncoderReranker
from sentence_transformers import CrossEncoder

# 1. 配置交叉编码器重排序
class SentenceTransformerReranker:
    def __init__(self, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model = CrossEncoder(model_name)
    
    def score(self, query_doc_pairs):
        return self.model.predict(query_doc_pairs)

# 2. 创建重排序器
reranker_model = SentenceTransformerReranker()
reranker = CrossEncoderReranker(
    model=reranker_model,
    top_n=5
)

# 3. 组合压缩和重排序
combined_retriever = ContextualCompressionRetriever(
    base_compressor=reranker,
    base_retriever=base_retriever
)

# 4. 效果对比测试
test_queries = [
    "深度学习的反向传播算法原理",
    "自然语言处理中的注意力机制",
    "推荐系统的协同过滤方法"
]

for query in test_queries:
    print(f"\n查询: {query}")
    
    # 原始检索
    original_docs = base_retriever.get_relevant_documents(query)
    print(f"原始检索Top-3文档标题: {[doc.metadata.get('title', 'Unknown')[:30] for doc in original_docs[:3]]}")
    
    # 重排序后检索
    reranked_docs = combined_retriever.get_relevant_documents(query)
    print(f"重排序后Top-3文档标题: {[doc.metadata.get('title', 'Unknown')[:30] for doc in reranked_docs[:3]]}")

3.4 进阶实践:检索增强全流程与RAG集成

from langchain.chains import RetrievalQA
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers.document_compressors import EmbeddingsFilter

# 1. 构建多阶段压缩管道
def create_enhanced_retriever(base_retriever, llm, embeddings):
    """创建增强检索器:嵌入过滤 + LLM提取 + 重排序"""
    
    # 第一阶段:嵌入相似度过滤
    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings,
        similarity_threshold=0.6,  # 相似度阈值
        k=20  # 保留前20个文档
    )
    
    # 第二阶段:LLM内容提取
    llm_extractor = LLMChainExtractor.from_llm(llm)
    
    # 第三阶段:重排序
    reranker = CrossEncoderReranker(
        model=SentenceTransformerReranker(),
        top_n=5
    )
    
    # 组合压缩管道
    pipeline_compressor = DocumentCompressorPipeline(
        transformers=[embeddings_filter, llm_extractor, reranker]
    )
    
    # 创建最终的压缩检索器
    enhanced_retriever = ContextualCompressionRetriever(
        base_compressor=pipeline_compressor,
        base_retriever=base_retriever
    )
    
    return enhanced_retriever

# 2. 创建增强RAG系统
enhanced_retriever = create_enhanced_retriever(base_retriever, llm, embeddings)

# 3. 构建QA链
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=enhanced_retriever,
    return_source_documents=True,
    verbose=True
)

# 4. 测试增强RAG效果
def test_rag_system(question):
    """测试RAG系统并展示检索增强效果"""
    print(f"\n问题: {question}")
    
    # 获取增强检索结果
    result = qa_chain({"query": question})
    
    print(f"\n回答: {result['result']}")
    print(f"\n使用的源文档数量: {len(result['source_documents'])}")
    
    # 展示源文档信息
    for i, doc in enumerate(result['source_documents']):
        print(f"\n源文档 {i+1} (长度: {len(doc.page_content)} 字符):")
        print(f"内容预览: {doc.page_content[:150]}...")
        if 'query_similarity_score' in doc.metadata:
            print(f"相似度分数: {doc.metadata['query_similarity_score']:.3f}")

# 5. 运行测试
test_questions = [
    "什么是Transformer架构的核心创新?",
    "如何评估机器学习模型的性能?",
    "深度学习在计算机视觉中有哪些应用?"
]

for question in test_questions:
    test_rag_system(question)

3.5 性能监控与效果评估

import time
from typing import List, Dict

def evaluate_retrieval_performance(retrievers: Dict[str, any], queries: List[str]):
    """评估不同检索器的性能"""
    results = {}
    
    for name, retriever in retrievers.items():
        print(f"\n评估 {name}...")
        
        total_time = 0
        total_docs = 0
        total_length = 0
        
        for query in queries:
            start_time = time.time()
            docs = retriever.get_relevant_documents(query)
            end_time = time.time()
            
            total_time += (end_time - start_time)
            total_docs += len(docs)
            total_length += sum(len(doc.page_content) for doc in docs)
        
        results[name] = {
            'avg_time': total_time / len(queries),
            'avg_docs': total_docs / len(queries),
            'avg_length': total_length / len(queries) if total_docs > 0 else 0
        }
    
    return results

# 性能对比测试
retrievers_to_test = {
    "基础检索": base_retriever,
    "压缩检索": compression_retriever,
    "增强检索": enhanced_retriever
}

performance_results = evaluate_retrieval_performance(retrievers_to_test, test_queries)

# 输出性能报告
print("\n=== 检索性能对比报告 ===")
for name, metrics in performance_results.items():
    print(f"\n{name}:")
    print(f"  平均响应时间: {metrics['avg_time']:.3f}秒")
    print(f"  平均返回文档数: {metrics['avg_docs']:.1f}")
    print(f"  平均文档总长度: {metrics['avg_length']:.0f}字符")

4. 设计考量:为什么LangChain要这样设计检索增强机制?

4.1 模块化与可组合性设计

LangChain采用"基础Retriever + 压缩器 + 重排序器"的模块化架构,具有以下优势:

4.1.1 组件解耦
# 每个组件都有清晰的职责边界
class ContextualCompressionRetriever(BaseRetriever):
    base_compressor: BaseDocumentCompressor  # 压缩策略
    base_retriever: RetrieverLike           # 基础检索器

# 支持任意组合
retriever = ContextualCompressionRetriever(
    base_compressor=任意压缩器,  # LLMChainExtractor, EmbeddingsFilter, etc.
    base_retriever=任意检索器   # VectorStoreRetriever, BM25Retriever, etc.
)
4.1.2 灵活扩展性
  • 水平扩展:可以轻松添加新的压缩器类型(如多模态压缩器)
  • 垂直扩展:可以组合多个压缩器形成处理管道
  • 策略替换:可以根据场景需求动态切换压缩策略

4.2 效果与效率平衡

4.2.1 为何先压缩再重排序?
1000个候选文档
嵌入过滤
100个文档
LLM提取
50个文档
重排序
Top-5文档

这种设计的核心考量:

  1. 计算成本控制:重排序的O(n²)复杂度要求先减少候选集规模
  2. 质量递进提升:每个阶段都在前一阶段的基础上进一步优化
  3. 资源合理分配:将计算资源集中在最有潜力的候选文档上
4.2.2 效率优化策略
# 分层处理策略
def layered_compression_strategy():
    return [
        # 第1层:快速过滤(低成本)
        EmbeddingsFilter(similarity_threshold=0.5, k=50),
        
        # 第2层:内容提取(中成本)
        LLMChainExtractor.from_llm(llm),
        
        # 第3层:精准排序(高成本,小规模)
        CrossEncoderReranker(top_n=5)
    ]

4.3 与RAG下游环节的协同设计

4.3.1 为LLM生成提供"高质量上下文"

检索增强的设计目标是为LLM提供最优质的上下文信息:

# 优化前:原始检索结果
raw_context = """
文档1: [2000字,包含50%无关内容]
文档2: [1800字,包含70%无关内容]  
文档3: [2200字,包含30%无关内容]
总计: 6000字,平均相关度50%
# 优化后:增强检索结果
enhanced_context = """
文档1: [500字,95%相关内容]
文档2: [400字,90%相关内容]
文档3: [600字,85%相关内容]
总计: 1500字,平均相关度90%
4.3.2 Token效率与生成质量双重优化
  • Token效率:压缩后的文档显著减少Token消耗
  • 生成质量:高相关度内容提升LLM回答的准确性和相关性
  • 成本控制:减少API调用成本,特别是对于付费LLM服务

4.4 通用性与扩展性考量

4.4.1 跨检索器兼容性
# 支持多种基础检索器
supported_retrievers = [
    VectorStoreRetriever,    # 向量检索
    BM25Retriever,          # 关键词检索
    EnsembleRetriever,      # 混合检索
    MultiQueryRetriever,    # 多查询检索
    # ... 任何实现BaseRetriever接口的检索器
]
4.4.2 跨模态检索潜力

设计架构天然支持多模态扩展:

# 未来可能的多模态压缩器
class MultiModalCompressor(BaseDocumentCompressor):
    def compress_documents(self, documents, query):
        # 处理文本、图像、音频等多模态内容
        # 根据查询类型选择合适的压缩策略
        pass

5. 替代方案与优化空间

5.1 替代实现方案

5.1.1 不依赖LangChain的检索质量优化

方案1:自定义检索管道

import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder

class CustomRetrievalPipeline:
    def __init__(self, embedding_model, rerank_model):
        self.embedding_model = SentenceTransformer(embedding_model)
        self.rerank_model = CrossEncoder(rerank_model)
    
    def retrieve_and_rerank(self, query, documents, top_k=5):
        # 1. 嵌入相似度初筛
        query_embedding = self.embedding_model.encode([query])
        doc_embeddings = self.embedding_model.encode([doc.content for doc in documents])
        
        similarities = np.dot(query_embedding, doc_embeddings.T)[0]
        top_indices = np.argsort(similarities)[-20:][::-1]  # 取前20个
        
        # 2. 交叉编码器重排序
        candidate_docs = [documents[i] for i in top_indices]
        pairs = [(query, doc.content) for doc in candidate_docs]
        rerank_scores = self.rerank_model.predict(pairs)
        
        # 3. 返回最终结果
        final_indices = np.argsort(rerank_scores)[-top_k:][::-1]
        return [candidate_docs[i] for i in final_indices]

# 优劣势对比
"""
优势:
- 完全自主控制,无框架依赖
- 可以针对特定场景深度优化
- 更轻量级的部署

劣势:
- 需要自行处理更多底层细节
- 缺少LangChain的生态集成
- 维护成本较高
"""

方案2:基于Elasticsearch的检索增强

from elasticsearch import Elasticsearch

class ElasticsearchEnhancedRetrieval:
    def __init__(self, es_client, index_name):
        self.es = es_client
        self.index = index_name
    
    def multi_stage_search(self, query, size=5):
        # 1. 多字段搜索
        search_body = {
            "query": {
                "multi_match": {
                    "query": query,
                    "fields": ["title^2", "content", "keywords"],
                    "type": "best_fields"
                }
            },
            "size": 50,  # 初筛50个
            "_source": ["title", "content", "metadata"]
        }
        
        # 2. 执行搜索
        response = self.es.search(index=self.index, body=search_body)
        
        # 3. 后处理重排序(可集成ML模型)
        candidates = response['hits']['hits']
        reranked = self._rerank_with_ml_model(query, candidates)
        
        return reranked[:size]
    
    def _rerank_with_ml_model(self, query, candidates):
        # 集成机器学习模型进行重排序
        pass

5.2 优化方向与核心代码思路

5.2.1 增强效果提升

动态压缩策略

class AdaptiveCompressionRetriever(BaseRetriever):
    """根据查询类型动态选择压缩策略"""
    
    def __init__(self, base_retriever, compressor_pool):
        self.base_retriever = base_retriever
        self.compressor_pool = compressor_pool
        self.query_classifier = self._init_query_classifier()
    
    def _get_relevant_documents(self, query: str) -> List[Document]:
        # 1. 分析查询类型
        query_type = self.query_classifier.classify(query)
        
        # 2. 选择最适合的压缩器
        compressor = self._select_compressor(query_type)
        
        # 3. 执行检索和压缩
        docs = self.base_retriever.invoke(query)
        return compressor.compress_documents(docs, query)
    
    def _select_compressor(self, query_type):
        strategy_map = {
            "factual": self.compressor_pool["llm_extractor"],      # 事实性查询用LLM提取
            "semantic": self.compressor_pool["embedding_filter"],   # 语义查询用嵌入过滤
            "complex": self.compressor_pool["pipeline"]            # 复杂查询用组合管道
        }
        return strategy_map.get(query_type, self.compressor_pool["default"])

多阶段重排序

class MultiStageReranker(BaseDocumentCompressor):
    """多阶段重排序:粗排 + 精排"""
    
    def __init__(self, coarse_ranker, fine_ranker, coarse_k=20, fine_k=5):
        self.coarse_ranker = coarse_ranker  # 快速粗排模型
        self.fine_ranker = fine_ranker      # 精确细排模型
        self.coarse_k = coarse_k
        self.fine_k = fine_k
    
    def compress_documents(self, documents, query, callbacks=None):
        # 第一阶段:粗排(快速筛选)
        coarse_scores = self.coarse_ranker.score(
            [(query, doc.page_content) for doc in documents]
        )
        coarse_indices = np.argsort(coarse_scores)[-self.coarse_k:][::-1]
        coarse_candidates = [documents[i] for i in coarse_indices]
        
        # 第二阶段:精排(精确排序)
        fine_scores = self.fine_ranker.score(
            [(query, doc.page_content) for doc in coarse_candidates]
        )
        fine_indices = np.argsort(fine_scores)[-self.fine_k:][::-1]
        
        return [coarse_candidates[i] for i in fine_indices]

用户反馈驱动优化

class FeedbackEnhancedRetriever(BaseRetriever):
    """基于用户反馈持续优化的检索器"""
    
    def __init__(self, base_retriever, feedback_store):
        self.base_retriever = base_retriever
        self.feedback_store = feedback_store
        self.personalization_model = self._init_personalization_model()
    
    def _get_relevant_documents(self, query: str, user_id: str = None) -> List[Document]:
        # 1. 获取基础检索结果
        docs = self.base_retriever.invoke(query)
        
        # 2. 应用个性化调整
        if user_id:
            user_preferences = self.feedback_store.get_user_preferences(user_id)
            docs = self._apply_personalization(docs, query, user_preferences)
        
        # 3. 应用全局反馈优化
        global_feedback = self.feedback_store.get_global_feedback(query)
        docs = self._apply_global_optimization(docs, global_feedback)
        
        return docs
    
    def collect_feedback(self, query: str, documents: List[Document], 
                        ratings: List[float], user_id: str = None):
        """收集用户反馈并更新模型"""
        feedback_data = {
            "query": query,
            "documents": documents,
            "ratings": ratings,
            "user_id": user_id,
            "timestamp": time.time()
        }
        
        # 存储反馈
        self.feedback_store.store_feedback(feedback_data)
        
        # 异步更新模型
        self._schedule_model_update()
5.2.2 效率优化

重排序结果缓存

import hashlib
from functools import lru_cache

class CachedReranker(BaseDocumentCompressor):
    """带缓存的重排序器"""
    
    def __init__(self, base_reranker, cache_size=1000, ttl=3600):
        self.base_reranker = base_reranker
        self.cache_size = cache_size
        self.ttl = ttl
        self.cache = {}
        self.cache_timestamps = {}
    
    def compress_documents(self, documents, query, callbacks=None):
        # 1. 生成缓存键
        cache_key = self._generate_cache_key(query, documents)
        
        # 2. 检查缓存
        if self._is_cache_valid(cache_key):
            return self.cache[cache_key]
        
        # 3. 执行重排序
        result = self.base_reranker.compress_documents(documents, query, callbacks)
        
        # 4. 更新缓存
        self._update_cache(cache_key, result)
        
        return result
    
    def _generate_cache_key(self, query, documents):
        """生成缓存键"""
        content_hash = hashlib.md5(
            (query + "".join(doc.page_content for doc in documents)).encode()
        ).hexdigest()
        return f"rerank_{content_hash}"
    
    def _is_cache_valid(self, cache_key):
        """检查缓存是否有效"""
        if cache_key not in self.cache:
            return False
        
        timestamp = self.cache_timestamps.get(cache_key, 0)
        return (time.time() - timestamp) < self.ttl

轻量化压缩器

class LightweightCompressor(BaseDocumentCompressor):
    """轻量级压缩器:基于规则和简单统计"""
    
    def __init__(self, max_length=500, relevance_threshold=0.3):
        self.max_length = max_length
        self.relevance_threshold = relevance_threshold
        self.stopwords = self._load_stopwords()
    
    def compress_documents(self, documents, query, callbacks=None):
        compressed_docs = []
        query_terms = self._extract_key_terms(query)
        
        for doc in documents:
            # 1. 计算简单相关性分数
            relevance_score = self._calculate_relevance(doc.page_content, query_terms)
            
            if relevance_score < self.relevance_threshold:
                continue
            
            # 2. 提取关键句子
            key_sentences = self._extract_key_sentences(
                doc.page_content, query_terms, self.max_length
            )
            
            # 3. 创建压缩文档
            compressed_doc = Document(
                page_content=key_sentences,
                metadata={**doc.metadata, "relevance_score": relevance_score}
            )
            compressed_docs.append(compressed_doc)
        
        return compressed_docs
    
    def _calculate_relevance(self, content, query_terms):
        """基于词频的简单相关性计算"""
        content_lower = content.lower()
        matches = sum(1 for term in query_terms if term in content_lower)
        return matches / len(query_terms) if query_terms else 0

批量处理优化

class BatchOptimizedRetriever(BaseRetriever):
    """批量优化的检索器"""
    
    def __init__(self, base_retriever, batch_size=32):
        self.base_retriever = base_retriever
        self.batch_size = batch_size
        self.query_queue = []
        self.result_cache = {}
    
    async def batch_retrieve(self, queries: List[str]) -> List[List[Document]]:
        """批量检索优化"""
        # 1. 分批处理
        results = []
        for i in range(0, len(queries), self.batch_size):
            batch = queries[i:i + self.batch_size]
            batch_results = await self._process_batch(batch)
            results.extend(batch_results)
        
        return results
    
    async def _process_batch(self, batch_queries):
        """并行处理一批查询"""
        tasks = [
            self.base_retriever.ainvoke(query) 
            for query in batch_queries
        ]
        return await asyncio.gather(*tasks)
5.2.3 功能扩展

跨模态检索增强

class MultiModalCompressionRetriever(BaseRetriever):
    """多模态压缩检索器"""
    
    def __init__(self, text_retriever, image_retriever, audio_retriever):
        self.retrievers = {
            "text": text_retriever,
            "image": image_retriever,
            "audio": audio_retriever
        }
        self.modal_fusion_model = self._init_fusion_model()
    
    def _get_relevant_documents(self, query: str, modalities: List[str] = None) -> List[Document]:
        if modalities is None:
            modalities = ["text"]  # 默认只检索文本
        
        # 1. 多模态并行检索
        modal_results = {}
        for modality in modalities:
            if modality in self.retrievers:
                modal_results[modality] = self.retrievers[modality].invoke(query)
        
        # 2. 跨模态融合排序
        fused_results = self.modal_fusion_model.fuse_and_rank(
            query, modal_results
        )
        
        return fused_results
    
    def _init_fusion_model(self):
        """初始化多模态融合模型"""
        class ModalFusionModel:
            def fuse_and_rank(self, query, modal_results):
                # 实现跨模态相关性计算和融合排序
                # 可以使用CLIP等多模态模型
                pass
        
        return ModalFusionModel()

领域适配优化

class DomainAdaptiveRetriever(BaseRetriever):
    """领域自适应检索器"""
    
    def __init__(self, base_retriever, domain_models):
        self.base_retriever = base_retriever
        self.domain_models = domain_models  # 不同领域的专用模型
        self.domain_classifier = self._init_domain_classifier()
    
    def _get_relevant_documents(self, query: str) -> List[Document]:
        # 1. 识别查询领域
        domain = self.domain_classifier.classify(query)
        
        # 2. 选择领域专用模型
        domain_model = self.domain_models.get(domain, self.domain_models["general"])
        
        # 3. 应用领域特定的检索增强
        docs = self.base_retriever.invoke(query)
        enhanced_docs = domain_model.enhance_retrieval(docs, query)
        
        return enhanced_docs
    
    def add_domain_model(self, domain: str, model):
        """动态添加领域模型"""
        self.domain_models[domain] = model

多语言检索增强

class MultilingualCompressionRetriever(BaseRetriever):
    """多语言压缩检索器"""
    
    def __init__(self, base_retriever, translation_service):
        self.base_retriever = base_retriever
        self.translation_service = translation_service
        self.language_detector = self._init_language_detector()
    
    def _get_relevant_documents(self, query: str, target_lang: str = "en") -> List[Document]:
        # 1. 检测查询语言
        query_lang = self.language_detector.detect(query)
        
        # 2. 如果需要,翻译查询
        if query_lang != target_lang:
            translated_query = self.translation_service.translate(
                query, source_lang=query_lang, target_lang=target_lang
            )
        else:
            translated_query = query
        
        # 3. 执行检索
        docs = self.base_retriever.invoke(translated_query)
        
        # 4. 如果需要,翻译结果
        if query_lang != target_lang:
            docs = self._translate_documents(docs, target_lang, query_lang)
        
        return docs
    
    def _translate_documents(self, docs, source_lang, target_lang):
        """翻译文档内容"""
        translated_docs = []
        for doc in docs:
            translated_content = self.translation_service.translate(
                doc.page_content, source_lang=source_lang, target_lang=target_lang
            )
            translated_doc = Document(
                page_content=translated_content,
                metadata={**doc.metadata, "original_language": source_lang}
            )
            translated_docs.append(translated_doc)
        
        return translated_docs

总结

LangChain的检索增强机制通过ContextualCompressionRetriever和重排序机制的巧妙结合,实现了从"冗余过滤到精准排序"的检索质量全面提升。其核心价值在于:

  1. 解决RAG核心痛点:有效应对冗余信息干扰、排序不准确、Token消耗过大等问题
  2. 模块化设计优势:提供灵活的组件组合能力,支持多样化的应用场景
  3. 效果与效率平衡:通过分层处理策略实现检索质量和计算效率的最优平衡
  4. 强大的扩展潜力:支持多模态、多语言、领域自适应等高级功能扩展

检索增强机制作为RAG系统的核心组件,为构建高质量的智能问答系统提供了坚实的技术基础。随着技术的不断发展,我们可以期待更多创新的压缩策略和排序算法的出现,进一步提升检索增强的效果和效率。

Logo

为武汉地区的开发者提供学习、交流和合作的平台。社区聚集了众多技术爱好者和专业人士,涵盖了多个领域,包括人工智能、大数据、云计算、区块链等。社区定期举办技术分享、培训和活动,为开发者提供更多的学习和交流机会。

更多推荐