RAG (检索增强生成) 系统演示

功能特性:

  • 🔍 智能文档检索:基于向量相似度检索相关文档
  • 🤖 智能问答:结合检索到的文档生成准确回答
  • 💾 持久化存储:使用Chroma向量数据库保存文档向量
  • 📊 相似度评分:显示检索文档的相关性分数和过滤
  • 📝 动态文档添加:支持运行时添加新文档到知识库
  • 🎯 结构化输出:可选的JSON格式结构化响应
  • 📈 系统状态监控:实时查看文档数量和系统配置

技术实现:

  • 嵌入模型:阿里云文本嵌入API (自定义Embeddings类)
  • 向量存储:Chroma (支持持久化)
  • 大语言模型:通过OpenAI兼容接口调用
  • 文档分割:RecursiveCharacterTextSplitter
  • 检索策略:向量相似度检索 + 阈值过滤
  • 界面:命令行交互界面

核心组件:

  1. AlibabaEmbeddings: 自定义嵌入模型适配器
  2. SimpleRAG: RAG系统核心实现类
  3. OutputModel: 结构化输出数据模型
  4. 交互式聊天界面:支持多种命令和模式切换
import os
from pathlib import Path
from datetime import datetime
from typing import List

# 第三方库导入
from dotenv import load_dotenv
from pydantic import BaseModel
from openai import OpenAI

# LangChain 组件导入
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

# ============================================================================
# 配置和环境变量加载
# ============================================================================

# 加载环境变量
load_dotenv()

# 获取配置信息并确保非空
API_KEY = os.getenv("API_KEY", "")
BASE_URL = os.getenv("BASE_URL", "") 
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "")
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "")

# 验证环境变量
if not all([API_KEY, BASE_URL, LLM_MODEL_NAME, EMBEDDING_MODEL_NAME]):
    raise ValueError("请确保 .env 文件中包含所有必要的环境变量")

print("🚀 正在初始化RAG代理系统...")

# ============================================================================
# 数据模型定义
# ============================================================================

class OutputModel(BaseModel):
    """定义RAG系统的结构化输出模型
    
    用于返回包含答案、来源文档和时间戳的结构化响应
    """
    answer: str
    """返回给用户的答案"""
    source_documents: List[str]
    """用于生成答案的文档标题列表"""
    date: str
    """回答的生成日期"""

# ============================================================================
# 自定义嵌入模型适配器
# ============================================================================

class AlibabaEmbeddings(Embeddings):
    """阿里云嵌入模型的自定义实现
    
    适配阿里云的文本嵌入API,使其兼容LangChain的Embeddings接口
    """
    
    def __init__(self, model: str, api_key: str, base_url: str):
        """初始化嵌入模型客户端
        
        Args:
            model: 嵌入模型名称
            api_key: API密钥
            base_url: API基础URL
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
    
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """批量嵌入文档列表
        
        Args:
            texts: 待嵌入的文本列表
            
        Returns:
            嵌入向量列表
        """
        embeddings = []
        for text in texts:
            response = self.client.embeddings.create(model=self.model, input=text)
            embeddings.append(response.data[0].embedding)
        return embeddings
    
    def embed_query(self, text: str) -> List[float]:
        """嵌入查询文本
        
        Args:
            text: 待嵌入的查询文本
            
        Returns:
            嵌入向量
        """
        response = self.client.embeddings.create(model=self.model, input=text)
        return response.data[0].embedding

# ============================================================================
# 模型和存储初始化
# ============================================================================

# 初始化LLM模型
print("📝 正在初始化LLM模型...")
llm = ChatOpenAI(
    model=LLM_MODEL_NAME,
    api_key=API_KEY, # type: ignore
    base_url=BASE_URL,
    # temperature=0.7  # 可调节创造性
)

# 设置向量数据库持久化目录
charom_dir = Path("/Users/colin/Desktop/home/code/for_work/dir_2025/month10/rag_bot/chroma_db")
print("🔍 正在初始化嵌入模型...")
if not charom_dir.exists():
    charom_dir.mkdir(parents=True)

# 初始化自定义嵌入模型
embeddings = AlibabaEmbeddings(
    model=EMBEDDING_MODEL_NAME,
    api_key=API_KEY, # type: ignore
    base_url=BASE_URL
)

# ============================================================================
# 示例数据准备
# ============================================================================

# 准备示例文档数据(可替换为任何领域的文档内容)
sample_documents = [
    {
        "title": "人工智能基础",
        "content": """
        人工智能(AI)是计算机科学的一个分支,旨在创造能够模拟人类智能的机器。
        AI包括机器学习、深度学习、自然语言处理、计算机视觉等多个子领域。
        机器学习是AI的核心技术之一,通过算法让计算机从数据中学习模式。
        深度学习使用神经网络来处理复杂的数据模式识别任务。
        """
    },
    {
        "title": "RAG技术原理",
        "content": """
        检索增强生成(RAG)是一种结合信息检索和文本生成的AI技术。
        RAG系统首先从知识库中检索相关信息,然后基于检索到的内容生成回答。
        这种方法可以提高AI回答的准确性和时效性,减少幻觉现象。
        RAG通常包含三个主要组件:文档索引、检索器和生成器。
        向量数据库是RAG系统的重要组成部分,用于存储和检索文档向量。
        """
    },
    {
        "title": "LangChain框架",
        "content": """
        LangChain是一个用于构建基于大型语言模型应用的Python框架。
        它提供了丰富的组件来处理文档加载、文本分割、向量存储和代理构建。
        LangChain支持多种向量数据库,包括Chroma、FAISS、Pinecone等。
        代理(Agent)是LangChain的重要概念,可以使用工具来完成复杂任务。
        工具链(Tool Chain)允许代理访问外部API和数据源。
        """
    },
    {
        "title": "Python编程",
        "content": """
        Python是一种高级编程语言,以其简洁的语法和强大的功能库而著名。
        Python在数据科学、机器学习和AI领域应用广泛。
        常用的Python AI库包括:scikit-learn、TensorFlow、PyTorch、pandas等。
        Python的包管理器pip使得安装和管理第三方库变得非常简单。
        虚拟环境(venv)帮助隔离不同项目的依赖关系。
        """
    }
]

# 文档预处理:创建LangChain文档对象
print("📚 正在准备文档数据...")
documents = []
for doc_data in sample_documents:
    doc = Document(
        page_content=f"标题是:{doc_data['title']}.内容是:{doc_data['content']}",
        metadata={"title": doc_data["title"]}
    )
    documents.append(doc)

# 配置文本分割器:将长文档分割成更小的块以提高检索效果
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # 每个块的最大字符数
    chunk_overlap=100,  # 块之间的重叠字符数,保持上下文连贯性
    length_function=len,
)

# 执行文档分割
print("✂️ 正在分割文档...")
all_splits = text_splitter.split_documents(documents)
print(f"📄 总共分割出 {len(all_splits)} 个文档块")

# 构建向量存储:将文档块转换为向量并存储
print("🏗️ 正在构建向量存储...")
vector_store = Chroma.from_documents(
    documents=all_splits, 
    embedding=embeddings,
    persist_directory=str(charom_dir)  # 持久化存储路径
)
print("✅ 向量存储构建完成!")

# ============================================================================
# 提示模板定义
# ============================================================================

# 创建RAG系统的提示模板:定义AI如何基于检索到的文档回答问题
rag_prompt = ChatPromptTemplate.from_messages([
    ("system", """你是一个智能助手,具有访问知识库的能力。

基于以下检索到的相关文档来回答用户的问题。如果文档中没有相关信息,请诚实地说明你不知道。

相关文档:
{context}

请基于上述文档回答用户的问题,保持回答的准确性和有用性。用中文回答。

"""),
    ("human", "{question}")
])

# ============================================================================
# RAG系统核心实现
# ============================================================================

class SimpleRAG:
    """简单的RAG系统实现
    
    整合文档检索和生成功能,提供完整的问答服务
    """
    
    def __init__(self, vector_store, llm, prompt_template):
        """初始化RAG系统
        
        Args:
            vector_store: 向量存储实例
            llm: 大语言模型实例
            prompt_template: 提示模板
        """
        self.vector_store = vector_store
        self.llm = llm 
        self.prompt_template = prompt_template
        self.structured_llm = llm.with_structured_output(OutputModel)

    def retrieve_context(self, query: str, k: int = 3, similarity_threshold: float = 0.8):
        """从向量存储中检索相关文档
        
        Args:
            query: 用户查询
            k: 检索文档数量上限
            similarity_threshold: 相似度阈值,距离越小表示越相似
            
        Returns:
            tuple: (格式化的上下文, 文档标题列表, 相似度分数列表)
        """
        print(f"🔍 正在检索相关信息: {query}")
        
        # 执行向量相似度检索,获取文档和距离分数
        retrieved_docs_with_scores = self.vector_store.similarity_search_with_score(query, k=k)
        
        # 根据阈值过滤低相关性文档
        filtered_docs = []
        scores = []
        for doc, score in retrieved_docs_with_scores:
            print(score)  # 输出原始距离分数用于调试
            if score <= similarity_threshold:  # 距离越小表示相似度越高
                filtered_docs.append((doc, score))
                scores.append(score)
        
        print(f"📖 检索到 {len(retrieved_docs_with_scores)} 个文档,其中 {len(filtered_docs)} 个高相关性文档")
        
        # 如果没有符合阈值的文档,返回空结果
        if not filtered_docs:
            print("⚠️ 未找到高相关性文档")
            return "未找到相关信息", [], []
        
        # 格式化检索结果为上下文文本
        context_parts = []
        source_docs = []
        for i, (doc, score) in enumerate(filtered_docs, 1):
            title = doc.metadata.get("title", f"文档{i}")
            content = doc.page_content.strip()
            print(f"   📄 {title} (距离分数: {score:.3f})")  # 显示文档标题和相关性
            context_parts.append(f"【{title}】\n{content}")
            source_docs.append(title)
        
        formatted_context = "\n\n".join(context_parts)
        return formatted_context, source_docs, scores
    
    def answer_question(self, question: str, use_structured_output: bool = False):
        """回答用户问题
        
        Args:
            question: 用户问题
            use_structured_output: 是否使用结构化输出
            
        Returns:
            str 或 OutputModel: 答案内容
        """
        # 检索相关上下文文档
        context, source_docs, scores = self.retrieve_context(question)
        
        # 检查是否找到相关信息
        if context == "未找到相关信息":
            no_info_msg = "抱歉,我无法在知识库中找到与您问题相关的信息。请尝试换个问法或添加相关文档。"
            if use_structured_output:
                return OutputModel(
                    answer=no_info_msg,
                    source_documents=[],
                    date=datetime.now().strftime("%Y-%m-%d")
                )
            return no_info_msg
        
        # 构建包含上下文的完整提示
        prompt = self.prompt_template.format(
            context=context,
            question=question
        )
        
        if use_structured_output:
            # 尝试生成结构化回答
            try:
                response = self.structured_llm.invoke(prompt)
                # 补充检索到的元信息
                response.source_documents = source_docs
                response.date = datetime.now().strftime("%Y-%m-%d")
                return response
            except Exception as e:
                print(f"结构化输出失败,回退到普通输出: {e}")
                # 失败时创建手动构造的结构化输出
                response = self.llm.invoke(prompt)
                return OutputModel(
                    answer=response.content,
                    source_documents=source_docs,
                    date=datetime.now().strftime("%Y-%m-%d")
                )
        else:
            # 生成普通文本回答
            response = self.llm.invoke(prompt)
            return response.content

# ============================================================================
# 系统功能函数
# ============================================================================

def statistics():
    """显示RAG系统当前统计信息和配置"""
    print("📊 RAG系统状态:")
    try:
        # 尝试获取向量存储中的文档数量
        doc_count = vector_store._collection.count()
        print(f"   - 向量存储中的文档块数量: {doc_count}")
    except:
        # 备用方法获取文档数量
        try:
            all_docs = vector_store.get()
            doc_count = len(all_docs['ids']) if all_docs['ids'] else 0
            print(f"   - 向量存储中的文档块数量: {doc_count}")
        except:
            print(f"   - 向量存储中的文档块数量: 无法获取")
    
    # 显示系统配置信息
    print(f"   - 使用的LLM模型: {LLM_MODEL_NAME}")
    print(f"   - 使用的嵌入模型: {EMBEDDING_MODEL_NAME}")
    print(f"   - 持久化目录: {charom_dir}")

def add_document_to_rag():
    """向RAG系统动态添加新文档
    
    演示如何在运行时向知识库添加新的文档内容
    """
    # 示例:添加Rust编程相关文档
    title = "Rust编程入门"
    content = "rust是一种注重性能和安全性的系统编程语言。它通过所有权系统管理内存,避免了传统语言中的许多内存错误。Rust适用于开发高性能应用程序,如操作系统、游戏引擎和嵌入式系统。Rust拥有丰富的生态系统和工具链,包括Cargo包管理器和Crates.io库。Rust社区活跃,提供了大量的学习资源和支持。"
    
    # 创建新文档对象
    new_doc = Document(
        page_content=content,
        metadata={"title": title}
    )
    
    # 分割新文档为适合的块大小
    splits = text_splitter.split_documents([new_doc])
    
    # 添加到现有的向量存储中
    vector_store.add_documents(splits)
    print(f"✅ 已添加新文档: {title},共分割出 {len(splits)} 个块")

def chat_with_rag():
    """交互式聊天界面
    
    提供命令行交互界面,支持多种操作模式:
    - 普通问答模式
    - 结构化输出模式
    - 文档添加功能
    - 系统状态查看
    """
    print("💬 开始与RAG系统对话(输入 'quit' 或 'exit' 退出, 输入 '+' 添加文档, 输入 'struct' 切换结构化输出)")
    print("=" * 50)
    
    use_structured_output = False  # 输出模式标志
    
    while True:
        try:
            # 获取用户输入,显示当前模式
            mode_indicator = '(结构化)' if use_structured_output else ''
            user_input = input(f"\n🧑 你{mode_indicator}: ").strip()
            
            # 处理退出命令
            if user_input.lower() in ['quit', 'exit', '退出', '结束']:
                statistics()  # 显示最终统计信息
                print("👋 再见!")
                break
            
            # 处理添加文档命令
            if user_input == "+":
                add_document_to_rag()
                continue
            
            # 处理切换输出模式命令
            if user_input.lower() == "struct":
                use_structured_output = not use_structured_output
                status = '启用' if use_structured_output else '禁用'
                print(f"🔄 已{status}结构化输出")
                continue
            
            # 跳过空输入
            if not user_input:
                continue
            
            # 处理用户问题
            print("\n🤖 RAG系统正在思考...")
            print("-" * 30)
            
            # 调用RAG系统获取回答
            answer = rag_system.answer_question(user_input, use_structured_output)
            
            # 根据输出模式显示结果
            print("-" * 30)
            if use_structured_output and isinstance(answer, OutputModel):
                # 显示结构化输出的详细信息
                print(f"🤖 助手: {answer.answer}")
                print(f"📚 参考文档: {', '.join(answer.source_documents)}")
                print(f"📅 回答日期: {answer.date}")
            else:
                # 显示普通文本回答
                print(f"🤖 助手: {answer}")
            print("=" * 50)
            
        except KeyboardInterrupt:
            print("\n\n👋 程序被用户中断,再见!")
            break
        except Exception as e:
            print(f"\n❌ 发生错误: {str(e)}")
            print("请重试...")

# ============================================================================
# 主程序入口
# ============================================================================

# 创建RAG系统实例
print("🤖 正在初始化RAG系统...")
rag_system = SimpleRAG(vector_store, llm, rag_prompt)
print("🎉 RAG系统初始化完成!")
print("=" * 50)

if __name__ == "__main__":
    try:
        print("\n🌟 欢迎使用RAG系统演示!")
        chat_with_rag()
    except Exception as e:
        print(f"❌ 系统初始化失败: {str(e)}")
        print("请检查配置和依赖是否正确安装")

Logo

为武汉地区的开发者提供学习、交流和合作的平台。社区聚集了众多技术爱好者和专业人士,涵盖了多个领域,包括人工智能、大数据、云计算、区块链等。社区定期举办技术分享、培训和活动,为开发者提供更多的学习和交流机会。

更多推荐