A Complete C++20 Database Connection Pool Implementation

I. Principles

1. Why use a connection pool?

Without a pool:
┌──────────────────────────────────────────────────────────────┐
│ request → TCP handshake → DB auth → execute SQL → disconnect │
│           ↑_____________________↑                            │
│           connection setup paid on every request             │
│           (several network round trips: ms or more)          │
└──────────────────────────────────────────────────────────────┘

With a pool:
┌──────────────────────────────────────────────────────────────┐
│ request → take from pool (μs) → execute SQL → return to pool │
│                                                              │
│ The pool keeps N established connections ready for reuse     │
└──────────────────────────────────────────────────────────────┘
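A minimal, single-threaded sketch of the difference, separate from the pool implemented below: it only counts how many connections each strategy opens for the same number of requests. `FakeConn`, `serve_without_pool` and `serve_with_pool` are hypothetical names, and constructing a `FakeConn` stands in for the handshake plus authentication.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a database connection: constructing one
// represents a full TCP handshake + authentication round.
struct FakeConn {
    static inline std::size_t opened = 0;  // connections established so far
    FakeConn() { ++opened; }
};

// Without a pool: every request connects, runs its SQL, and disconnects.
void serve_without_pool(std::size_t requests) {
    for (std::size_t i = 0; i < requests; ++i) {
        FakeConn conn;  // connect
        (void)conn;     // ... execute SQL ...
    }                   // disconnect when conn goes out of scope
}

// With a minimal single-threaded pool: pre-open `pool_size` connections
// and hand them out again instead of reconnecting.
void serve_with_pool(std::size_t requests, std::size_t pool_size) {
    assert(pool_size > 0 && "need at least one pooled connection");
    std::vector<std::unique_ptr<FakeConn>> idle;
    for (std::size_t i = 0; i < pool_size; ++i)
        idle.push_back(std::make_unique<FakeConn>());

    for (std::size_t i = 0; i < requests; ++i) {
        auto conn = std::move(idle.back());  // acquire: no handshake
        idle.pop_back();
        // ... execute SQL on *conn ...
        idle.push_back(std::move(conn));     // release: keep it open
    }
}
```

For 1,000 requests the first function opens 1,000 connections and the second opens only `pool_size`; everything saved per request is the setup cost drawn in the first box.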

2. Core state machine of a pooled connection

                    ┌──────────────┐
                    │    Create    │
                    └──────┬───────┘
                           │
                    ┌──────▼───────┐
              ┌────►│     Idle     │◄──────────────┐
              │     └──────┬───────┘               │
              │            │ acquire()             │
              │     ┌──────▼───────┐               │
              │     │    In use    │               │
              │     └──────┬───────┘               │
              │            │ release()             │
              │     ┌──────▼───────┐               │
              │     │ Health check │──── pass ─────┘
              │     └──────┬───────┘
              │            │ fail
              │     ┌──────▼───────┐
              └─────│  Reconnect   │
                    └──────┬───────┘
                           │ too many failures
                    ┌──────▼───────┐
                    │   Destroy    │
                    └──────────────┘
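The diagram can be encoded as an explicit transition function. This is an illustrative sketch, not code from the pool below: `State`, `Event`, `Lifecycle` and the `max_failures` threshold of 3 are assumptions (the article's config has a comparable `max_reconnect_attempts`).

```cpp
#include <cassert>
#include <optional>

// Hypothetical encoding of the connection lifecycle diagram.
enum class State { Created, Idle, InUse, HealthCheck, Reconnecting, Destroyed };
enum class Event { Connected, Acquire, Release, CheckPassed, CheckFailed,
                   ReconnectOk, ReconnectFailed };

struct Lifecycle {
    State state        = State::Created;
    int   failures     = 0;  // consecutive failed reconnects
    int   max_failures = 3;  // "too many failures" threshold (assumed)

    // Apply one event. Returns false and leaves the state unchanged if the
    // event is not allowed in the current state.
    bool on(Event e) {
        std::optional<State> next;
        switch (state) {
        case State::Created:
            if (e == Event::Connected) next = State::Idle;
            break;
        case State::Idle:
            if (e == Event::Acquire) next = State::InUse;
            break;
        case State::InUse:
            if (e == Event::Release) next = State::HealthCheck;
            break;
        case State::HealthCheck:
            if (e == Event::CheckPassed) next = State::Idle;
            if (e == Event::CheckFailed) next = State::Reconnecting;
            break;
        case State::Reconnecting:
            if (e == Event::ReconnectOk) { failures = 0; next = State::Idle; }
            if (e == Event::ReconnectFailed)
                next = (++failures >= max_failures) ? State::Destroyed
                                                    : State::Reconnecting;
            break;
        case State::Destroyed:
            break;  // terminal state: no transitions out
        }
        if (!next) return false;
        state = *next;
        return true;
    }
};
```

Keeping every transition in one function makes illegal moves (for example acquiring a connection that is still being health-checked) detectable instead of silent.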

3. Overall architecture

┌─────────────────────────────────────────────────────────────┐
│                       ConnectionPool                        │
│                                                             │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐      │
│  │ Connection  │    │ Connection  │    │ Connection  │      │
│  │  [idle]     │    │  [in use]   │    │  [idle]     │      │
│  └─────────────┘    └─────────────┘    └─────────────┘      │
│         ↑                                     ↑             │
│  ┌──────┴─────────────────────────────────────┴──────────┐  │
│  │                      idle_queue_                      │  │
│  │             (condition variable + deque)              │  │
│  └───────────────────────────────────────────────────────┘  │
│                                                             │
│  ┌───────────────────────────────────────────────────────┐  │
│  │             background maintenance thread             │  │
│  │  ① keepalive heartbeat     ② idle-timeout eviction    │  │
│  │  ③ keep min connections    ④ stats reporting          │  │
│  └───────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────┘
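The `idle_queue_` box pairs a `std::deque` with a `std::condition_variable`. A stripped-down sketch of just that part, assuming `T` stands in for a connection handle and ignoring capacity limits and health checks (`IdleQueue` and `pop_for` are hypothetical names):

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <optional>
#include <utility>

// A deque guarded by a mutex, plus a condition variable so borrowers can
// sleep until something is returned. T stands in for a connection handle.
template <typename T>
class IdleQueue {
public:
    void push(T item) {
        {
            std::scoped_lock lock(mutex_);
            items_.push_back(std::move(item));
        }
        cond_.notify_one();  // wake one waiting borrower
    }

    // Wait up to `timeout` for an item; std::nullopt means we timed out.
    std::optional<T> pop_for(std::chrono::milliseconds timeout) {
        std::unique_lock lock(mutex_);
        // The predicate overload re-checks after every wakeup, so spurious
        // wakeups and items taken by another thread are handled correctly.
        if (!cond_.wait_for(lock, timeout, [this] { return !items_.empty(); }))
            return std::nullopt;
        T item = std::move(items_.front());
        items_.pop_front();
        return item;
    }

private:
    std::mutex              mutex_;
    std::condition_variable cond_;
    std::deque<T>           items_;
};
```

The predicate overload of `wait_for` is the important detail: a bare `wait_for` can return on a spurious wakeup, or after another borrower has already taken the item that triggered the notification.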

II. Full Implementation

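File layout

db_pool/
├── db_error.hpp          # error handling
├── db_config.hpp         # configuration
├── db_connection.hpp     # connection abstraction layer
├── mysql_connection.hpp  # MySQL implementation
├── connection_pool.hpp   # connection pool core
├── pool_guard.hpp        # RAII connection guard
├── pool_monitor.hpp      # monitoring / statistics
└── db_pool.hpp           # single entry-point header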
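`pool_guard.hpp` is listed as an RAII connection guard. One possible shape for such a guard is sketched here; it is hypothetical (`TinyPool` and `PoolGuard` are illustrative names, not the article's `pool_guard.hpp`) and uses plain `int` handles so it stays self-contained.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical pool used only to exercise the guard: handles are plain ints.
struct TinyPool {
    std::vector<int> idle{1, 2, 3};
    int  acquire()      { int h = idle.back(); idle.pop_back(); return h; }
    void release(int h) { idle.push_back(h); }
};

// Owns a borrowed handle and gives it back exactly once, including when the
// scope is left through an exception.
template <typename Pool, typename Handle>
class PoolGuard {
public:
    PoolGuard(Pool& pool, Handle h) : pool_(&pool), handle_(std::move(h)) {}
    ~PoolGuard() { if (pool_) pool_->release(std::move(handle_)); }

    PoolGuard(const PoolGuard&)            = delete;  // a copy would release twice
    PoolGuard& operator=(const PoolGuard&) = delete;
    PoolGuard(PoolGuard&& other) noexcept
        : pool_(std::exchange(other.pool_, nullptr)),
          handle_(std::move(other.handle_)) {}
    PoolGuard& operator=(PoolGuard&&) = delete;       // kept simple on purpose

    Handle& get() noexcept { return handle_; }

private:
    Pool*  pool_;    // nullptr once moved-from
    Handle handle_;
};
```

Deleting the copy operations is the key design choice. `ConnectionPool::query()` below applies the same idea with a local `ReleaseGuard` struct.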

db_error.hpp

#pragma once
#include <optional>   // std::optional (used by Result<T>)
#include <stdexcept>
#include <string>
#include <system_error>
#include <format>

namespace db {

// ── Error codes ───────────────────────────────────────────
enum class PoolErrc : int {
    Success          = 0,
    Timeout          = 1,  // timed out waiting for a connection
    PoolExhausted    = 2,  // pool is at max_connections
    ConnectFailed    = 3,  // could not connect to the database
    QueryFailed      = 4,  // SQL execution failed
    InvalidConn      = 5,  // connection is no longer valid
    PoolClosed       = 6,  // pool has been closed
    MaxRetryExceeded = 7,  // retry limit exceeded
};

// Error category for pool errors (plugs into <system_error>)
struct PoolErrCategory : std::error_category {
    const char* name() const noexcept override { return "db_pool"; }

    std::string message(int ev) const override {
        switch (static_cast<PoolErrc>(ev)) {
            case PoolErrc::Success:          return "success";
            case PoolErrc::Timeout:          return "wait for connection timeout";
            case PoolErrc::PoolExhausted:    return "connection pool exhausted";
            case PoolErrc::ConnectFailed:    return "database connection failed";
            case PoolErrc::QueryFailed:      return "query execution failed";
            case PoolErrc::InvalidConn:      return "invalid connection";
            case PoolErrc::PoolClosed:       return "pool is closed";
            case PoolErrc::MaxRetryExceeded: return "max retry exceeded";
            default:                         return "unknown error";
        }
    }
};

inline const PoolErrCategory& pool_category() {
    static PoolErrCategory cat;
    return cat;
}

inline std::error_code make_error_code(PoolErrc e) {
    return { static_cast<int>(e), pool_category() };
}

// ── Exception type ────────────────────────────────────────
class PoolError : public std::runtime_error {
public:
    explicit PoolError(PoolErrc code, std::string msg = "")
        : std::runtime_error(
            msg.empty() ? pool_category().message(static_cast<int>(code))
                        : std::format("[{}] {}", 
                            pool_category().message(static_cast<int>(code)),
                            msg))
        , code_(code) {}

    PoolErrc code() const noexcept { return code_; }

private:
    PoolErrc code_;
};

// ── Result type (for the non-throwing API) ────────────────
template<typename T>
class Result {
public:
    static Result ok(T val) {
        Result r; r.val_ = std::move(val); r.ok_ = true; return r;
    }
    static Result err(PoolErrc e, std::string msg = "") {
        Result r; r.ec_ = make_error_code(e);
        r.msg_ = std::move(msg); return r;
    }

    explicit operator bool()     const noexcept { return ok_; }
    const T& value()             const          { check(); return *val_; }
    T&       value()                            { check(); return *val_; }
    std::error_code error()      const noexcept { return ec_; }
    const std::string& message() const noexcept { return msg_; }

private:
    void check() const {
        if (!ok_) throw PoolError(
            static_cast<PoolErrc>(ec_.value()), msg_);
    }

    std::optional<T> val_;
    std::error_code  ec_;
    std::string      msg_;
    bool             ok_ = false;
};

template<>
class Result<void> {
public:
    static Result ok()  { Result r; r.ok_ = true; return r; }
    static Result err(PoolErrc e, std::string msg = "") {
        Result r; r.ec_ = make_error_code(e);
        r.msg_ = std::move(msg); return r;
    }

    explicit operator bool()     const noexcept { return ok_; }
    std::error_code error()      const noexcept { return ec_; }
    const std::string& message() const noexcept { return msg_; }

private:
    std::error_code ec_;
    std::string     msg_;
    bool            ok_ = false;
};

} // namespace db

// Register the enum: enables implicit PoolErrc -> std::error_code conversion
template<>
struct std::is_error_code_enum<db::PoolErrc> : std::true_type {};
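Because `make_error_code` lives in the enum's namespace and `std::is_error_code_enum` is specialized, a `PoolErrc` value converts implicitly to `std::error_code` and can be compared with one directly. A cut-down, self-contained sketch of the same pattern (namespace `demo` and its three enumerators are illustrative):

```cpp
#include <cassert>
#include <string>
#include <system_error>
#include <type_traits>

namespace demo {

// Cut-down copy of the PoolErrc pattern, for illustration only.
enum class PoolErrc { Success = 0, Timeout = 1, PoolClosed = 6 };

struct Category : std::error_category {
    const char* name() const noexcept override { return "db_pool"; }
    std::string message(int ev) const override {
        switch (static_cast<PoolErrc>(ev)) {
            case PoolErrc::Success:    return "success";
            case PoolErrc::Timeout:    return "wait for connection timeout";
            case PoolErrc::PoolClosed: return "pool is closed";
        }
        return "unknown error";
    }
};

inline const std::error_category& category() {
    static Category cat;  // one instance: categories compare by address
    return cat;
}

// Found by argument-dependent lookup when an error_code is built from the enum.
inline std::error_code make_error_code(PoolErrc e) {
    return {static_cast<int>(e), category()};
}

} // namespace demo

// Opt in to the implicit demo::PoolErrc -> std::error_code conversion.
template <>
struct std::is_error_code_enum<demo::PoolErrc> : std::true_type {};
```

The specialization only has to be visible before the first implicit conversion; explicit `make_error_code(e)` calls, as in `Result`, work without it.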

db_config.hpp

#pragma once
#include <string>
#include <chrono>
#include <cstddef>
#include <cstdint>   // std::uint16_t
#include <format>    // std::format in dsn()

namespace db {

using namespace std::chrono_literals;

// ════════════════════════════════════════════════════════════
//  PoolConfig: connection pool configuration
// ════════════════════════════════════════════════════════════
struct PoolConfig {
    // ── Database connection parameters ────────────────────
    std::string host     = "127.0.0.1";
    std::uint16_t port   = 3306;
    std::string user     = "root";
    std::string password = "";
    std::string database = "";
    std::string charset  = "utf8mb4";

    // ── Pool size ─────────────────────────────────────────
    std::size_t min_connections  = 2;   // minimum connections (always kept open)
    std::size_t max_connections  = 16;  // maximum connections
    std::size_t initial_connections = 4; // connections pre-created at startup

    // ── Timeouts ──────────────────────────────────────────
    std::chrono::milliseconds connect_timeout  = 5s;   // establishing a connection
    std::chrono::milliseconds acquire_timeout  = 3s;   // waiting to borrow a connection
    std::chrono::milliseconds query_timeout    = 30s;  // query execution
    std::chrono::seconds      idle_timeout     = 600s; // max lifetime of an idle connection
    std::chrono::seconds      keepalive_interval = 60s;// heartbeat interval

    // ── Retry policy ──────────────────────────────────────
    int    max_reconnect_attempts = 3;    // max connect attempts per connection
    std::chrono::milliseconds reconnect_delay = 500ms; // delay between attempts

    // ── Health check ──────────────────────────────────────
    bool        enable_health_check = true;
    std::string ping_query = "SELECT 1";  // heartbeat SQL

    // ── Validation ────────────────────────────────────────
    [[nodiscard]] bool valid() const noexcept {
        return !host.empty()
            && port > 0
            && !user.empty()
            && min_connections <= max_connections
            && max_connections > 0
            && initial_connections <= max_connections;
    }

    // Build a DSN string for logging (the password is deliberately left out)
    [[nodiscard]] std::string dsn() const {
        return std::format("mysql://{}@{}:{}/{}",
            user, host, port, database);
    }
};

// ── Builder for constructing a PoolConfig ─────────────────
class PoolConfigBuilder {
public:
    PoolConfigBuilder& host(std::string h)
        { cfg_.host = std::move(h); return *this; }
    PoolConfigBuilder& port(std::uint16_t p)
        { cfg_.port = p; return *this; }
    PoolConfigBuilder& user(std::string u)
        { cfg_.user = std::move(u); return *this; }
    PoolConfigBuilder& password(std::string pw)
        { cfg_.password = std::move(pw); return *this; }
    PoolConfigBuilder& database(std::string db)
        { cfg_.database = std::move(db); return *this; }
    PoolConfigBuilder& min_conn(std::size_t n)
        { cfg_.min_connections = n; return *this; }
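    PoolConfigBuilder& max_conn(std::size_t n)
        { cfg_.max_connections = n; return *this; }
    PoolConfigBuilder& initial_conn(std::size_t n)  // without this, stays at 4
        { cfg_.initial_connections = n; return *this; }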
    PoolConfigBuilder& acquire_timeout(std::chrono::milliseconds t)
        { cfg_.acquire_timeout = t; return *this; }
    PoolConfigBuilder& idle_timeout(std::chrono::seconds t)
        { cfg_.idle_timeout = t; return *this; }
    PoolConfigBuilder& keepalive(std::chrono::seconds t)
        { cfg_.keepalive_interval = t; return *this; }

    [[nodiscard]] PoolConfig build() const { return cfg_; }

private:
    PoolConfig cfg_;
};

} // namespace db
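`PoolConfigBuilder::build()` returns the config unchecked; `valid()` is only consulted later, by the `ConnectionPool` constructor. A variant that fails fast at `build()` time could look like this (a sketch; `MiniConfig` and `CheckedBuilder` are hypothetical, trimmed-down names):

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <stdexcept>
#include <string>

using namespace std::chrono_literals;

// Trimmed-down config with only the fields the checks below need.
struct MiniConfig {
    std::string host = "127.0.0.1";
    std::size_t min_connections     = 2;
    std::size_t max_connections     = 16;
    std::size_t initial_connections = 4;
    std::chrono::milliseconds acquire_timeout = 3s;  // 3s converts exactly to 3000ms

    bool valid() const noexcept {
        return !host.empty() && max_connections > 0
            && min_connections <= max_connections
            && initial_connections <= max_connections;
    }
};

// Builder that refuses to hand out an invalid config.
class CheckedBuilder {
public:
    CheckedBuilder& max_conn(std::size_t n) { cfg_.max_connections = n; return *this; }
    CheckedBuilder& min_conn(std::size_t n) { cfg_.min_connections = n; return *this; }
    CheckedBuilder& acquire_timeout(std::chrono::milliseconds t) {
        cfg_.acquire_timeout = t;
        return *this;
    }

    MiniConfig build() const {
        if (!cfg_.valid())
            throw std::invalid_argument("invalid pool config");
        return cfg_;
    }

private:
    MiniConfig cfg_;
};
```

Note that `initial_connections` defaults to 4, so lowering `max_connections` below 4 without also lowering `initial_connections` makes `valid()` fail.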

db_connection.hpp

#pragma once
#include "db_error.hpp"
#include "db_config.hpp"
#include <string>
#include <vector>
#include <unordered_map>
#include <cstdint>   // std::int64_t, std::uint64_t
#include <chrono>
#include <memory>
#include <functional>
#include <variant>

namespace db {

// ── Query result row ──────────────────────────────────────
using FieldValue = std::variant<
    std::monostate,   // NULL
    std::int64_t,
    double,
    std::string
>;

struct Row {
    std::vector<std::string>      columns;
    std::vector<FieldValue>       values;
    std::unordered_map<std::string, std::size_t> index;

    const FieldValue& operator[](const std::string& col) const {
        auto it = index.find(col);
        if (it == index.end())
            throw PoolError(PoolErrc::QueryFailed,
                std::format("column '{}' not found", col));
        return values[it->second];
    }

    // Typed access; std::get throws std::bad_variant_access on a type mismatch or NULL
    template<typename T>
    T get(const std::string& col) const {
        return std::get<T>((*this)[col]);
    }

    std::string get_str(const std::string& col) const {
        return get<std::string>(col);
    }
    std::int64_t get_int(const std::string& col) const {
        return get<std::int64_t>(col);
    }
    double get_double(const std::string& col) const {
        return get<double>(col);
    }
    bool is_null(const std::string& col) const {
        return std::holds_alternative<std::monostate>((*this)[col]);
    }
};

struct QueryResult {
    std::vector<Row>  rows;
    std::uint64_t     affected_rows = 0;
    std::uint64_t     last_insert_id = 0;
    std::size_t       field_count = 0;

    bool   empty()  const noexcept { return rows.empty(); }
    std::size_t size() const noexcept { return rows.size(); }
    const Row& operator[](std::size_t i) const { return rows[i]; }

    // Range-based for support
    auto begin() const { return rows.begin(); }
    auto end()   const { return rows.end(); }
};

// ── Prepared-statement parameters ─────────────────────────
using Param = std::variant<
    std::monostate,
    std::int64_t,
    double,
    std::string,
    bool
>;

using Params = std::vector<Param>;

// ════════════════════════════════════════════════════════════
//  IConnection: abstract database connection interface
// ════════════════════════════════════════════════════════════
class IConnection {
public:
    virtual ~IConnection() = default;

    // ── Basic operations ──────────────────────────────────
    virtual bool connect(const PoolConfig& cfg)  = 0;
    virtual void disconnect()                    = 0;
    virtual bool is_connected()            const = 0;
    virtual bool ping()                          = 0;  // heartbeat check

    // ── SQL execution ─────────────────────────────────────
    virtual QueryResult execute(const std::string& sql)       = 0;
    virtual QueryResult execute(const std::string& sql,
                                const Params& params)         = 0;

    // ── Transactions ──────────────────────────────────────
    virtual void begin_transaction()   = 0;
    virtual void commit()              = 0;
    virtual void rollback()            = 0;
    virtual bool in_transaction() const= 0;

    // ── Escaping (SQL-injection defence; bound Params are safer) ──
    virtual std::string escape(const std::string& s) const = 0;

    // ── Metadata ──────────────────────────────────────────
    virtual std::uint64_t last_insert_id()  const = 0;
    virtual std::uint64_t affected_rows()   const = 0;
    virtual std::string   server_version()  const = 0;
    virtual std::string   error_message()   const = 0;
    virtual int           error_code()      const = 0;

    // ── Used internally by the pool ───────────────────────
    std::size_t  pool_id   = 0;       // ID within the pool
    bool         in_use    = false;   // currently borrowed?
    std::chrono::steady_clock::time_point
                 last_used = std::chrono::steady_clock::now();
    std::chrono::steady_clock::time_point
                 created_at= std::chrono::steady_clock::now();
    int          fail_count = 0;      // consecutive failures
};

using ConnectionPtr = std::shared_ptr<IConnection>;
using ConnectionFactory = std::function<ConnectionPtr()>;

} // namespace db
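`Row::get<T>` forwards to `std::get<T>`, which throws `std::bad_variant_access` (not `PoolError`) when the column holds a different alternative, including NULL. Two non-throwing alternatives, sketched with the same `FieldValue` alternatives (`to_text` and `try_get` are hypothetical helpers):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>

// Same alternatives as db::FieldValue (std::monostate represents SQL NULL).
using FieldValue = std::variant<std::monostate, std::int64_t, double, std::string>;

// Render any field as text without risking std::bad_variant_access.
std::string to_text(const FieldValue& v) {
    return std::visit([](const auto& x) -> std::string {
        using T = std::decay_t<decltype(x)>;
        if constexpr (std::is_same_v<T, std::monostate>) return "NULL";
        else if constexpr (std::is_same_v<T, std::string>) return x;
        else return std::to_string(x);  // std::int64_t or double
    }, v);
}

// Typed access that returns nullptr if the field holds another type (or NULL).
template <typename T>
const T* try_get(const FieldValue& v) { return std::get_if<T>(&v); }
```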

mysql_connection.hpp

#pragma once
#include "db_connection.hpp"
#include <atomic>
#include <print>   // C++23 (std::println)

// Mock implementation; a real project would link against
// libmysqlclient or MariaDB Connector/C instead.
// #include <mysql/mysql.h>

namespace db {

// ════════════════════════════════════════════════════════════
//  MockMySQLConnection: simulated MySQL connection (replaceable)
//  A real implementation swaps the mock logic for mysql_*() calls
// ════════════════════════════════════════════════════════════
class MockMySQLConnection : public IConnection {
public:
    MockMySQLConnection() = default;
    ~MockMySQLConnection() override { disconnect(); }

    // ── Connect ───────────────────────────────────────────
    bool connect(const PoolConfig& cfg) override {
        cfg_ = cfg;
        // Real code:
        // mysql_ = mysql_init(nullptr);
        // mysql_options(mysql_, MYSQL_OPT_CONNECT_TIMEOUT, &timeout);
        // connected_ = (mysql_real_connect(mysql_,
        //     cfg.host.c_str(), cfg.user.c_str(),
        //     cfg.password.c_str(), cfg.database.c_str(),
        //     cfg.port, nullptr, 0) != nullptr);

        // Mock: deterministically fail every 10th connect (for testing).
        // The counter is shared by all instances and pool threads, so it
        // must be atomic.
        static std::atomic<int> counter{0};
        connected_ = (++counter % 10 != 0);
        if (connected_) {
            std::println("[Conn#{}] connected to {}", pool_id, cfg.dsn());
        }
        return connected_;
    }

    void disconnect() override {
        if (connected_) {
            // mysql_close(mysql_);
            std::println("[Conn#{}] disconnected", pool_id);
            connected_ = false;
        }
    }

    bool is_connected() const override { return connected_; }

    bool ping() override {
        if (!connected_) return false;
        // return mysql_ping(mysql_) == 0;
        return true;  // mock: ping always succeeds
    }

    // ── Execute SQL ───────────────────────────────────────
    QueryResult execute(const std::string& sql) override {
        return execute(sql, {});
    }

    QueryResult execute(const std::string& sql,
                        const Params& params) override
    {
        if (!connected_)
            throw PoolError(PoolErrc::InvalidConn);

        last_used = std::chrono::steady_clock::now();

        // Real code:
        // MYSQL_STMT* stmt = mysql_stmt_init(mysql_);
        // mysql_stmt_prepare(stmt, sql.c_str(), sql.size());
        // ... 绑定参数 ...
        // mysql_stmt_execute(stmt);
        // ... 读取结果集 ...

        // Mock: return a canned result
        QueryResult result;
        if (sql.find("SELECT") != std::string::npos ||
            sql.find("select") != std::string::npos)
        {
            Row row;
            row.columns = {"id", "name", "value"};
            row.values  = {
                std::int64_t(1),
                std::string("test"),
                3.14
            };
            row.index = {{"id",0}, {"name",1}, {"value",2}};
            result.rows.push_back(row);
            result.field_count = 3;
        }
        result.affected_rows  = 1;
        result.last_insert_id = last_insert_id_;
        return result;
    }

    // ── Transactions ──────────────────────────────────────
    void begin_transaction() override {
        if (in_transaction_)
            throw PoolError(PoolErrc::QueryFailed, "already in transaction");
        // mysql_query(mysql_, "BEGIN");
        in_transaction_ = true;
        std::println("[Conn#{}] BEGIN", pool_id);
    }

    void commit() override {
        // mysql_commit(mysql_);
        in_transaction_ = false;
        std::println("[Conn#{}] COMMIT", pool_id);
    }

    void rollback() override {
        // mysql_rollback(mysql_);
        in_transaction_ = false;
        std::println("[Conn#{}] ROLLBACK", pool_id);
    }

    bool in_transaction() const override { return in_transaction_; }

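    // ── Miscellaneous ─────────────────────────────────────
    std::string escape(const std::string& s) const override {
        // Real code (no VLA; variable-length arrays are not standard C++):
        // std::string buf(s.size() * 2 + 1, '\0');
        // buf.resize(mysql_real_escape_string(mysql_, buf.data(),
        //                                     s.c_str(), s.size()));
        // return buf;
        return s;  // mock: no escaping at all, unsafe for real use
    }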

    std::uint64_t last_insert_id()  const override { return last_insert_id_; }
    std::uint64_t affected_rows()   const override { return affected_rows_;  }
    std::string   server_version()  const override { return "8.0.32-mock";   }
    std::string   error_message()   const override { return err_msg_;        }
    int           error_code()      const override { return err_code_;       }

private:
    PoolConfig    cfg_;
    bool          connected_     = false;
    bool          in_transaction_= false;
    std::uint64_t last_insert_id_= 0;
    std::uint64_t affected_rows_ = 0;
    std::string   err_msg_;
    int           err_code_      = 0;
};

// ── Factory function ──────────────────────────────────────
inline ConnectionFactory make_mysql_factory() {
    return []() -> ConnectionPtr {
        return std::make_shared<MockMySQLConnection>();
    };
}

} // namespace db
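Because the pool depends only on a `ConnectionFactory`, tests can inject failures without touching the pool. A hedged sketch with a cut-down interface (`Conn`, `FlakyConn` and `make_flaky_factory` are hypothetical): each factory owns its own atomic attempt counter, so concurrent connects are race-free and separate tests do not share state the way a function-local `static` would.

```cpp
#include <atomic>
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Minimal stand-in for IConnection: only connect() matters here.
struct Conn {
    virtual ~Conn() = default;
    virtual bool connect() = 0;
};
using ConnPtr = std::shared_ptr<Conn>;
using Factory = std::function<ConnPtr()>;

// Failure-injecting test double: every `fail_every`-th connect fails.
struct FlakyConn : Conn {
    std::shared_ptr<std::atomic<int>> attempts;
    int fail_every;
    FlakyConn(std::shared_ptr<std::atomic<int>> a, int n)
        : attempts(std::move(a)), fail_every(n) {}
    bool connect() override {
        int n = attempts->fetch_add(1, std::memory_order_relaxed) + 1;
        return n % fail_every != 0;
    }
};

// Each factory gets its own counter, shared by the connections it creates.
inline Factory make_flaky_factory(int fail_every) {
    assert(fail_every > 0);
    auto attempts = std::make_shared<std::atomic<int>>(0);
    return [attempts, fail_every]() -> ConnPtr {
        return std::make_shared<FlakyConn>(attempts, fail_every);
    };
}
```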

pool_monitor.hpp

#pragma once
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <string>
#include <format>
#include <print>    // C++23 (std::println)

namespace db {

// ════════════════════════════════════════════════════════════
//  PoolStats: connection pool statistics
// ════════════════════════════════════════════════════════════
struct PoolStats {
    // Connection counts
    std::size_t total_connections  = 0;
    std::size_t idle_connections   = 0;
    std::size_t active_connections = 0;
    std::size_t pending_requests   = 0;  // requests waiting for a connection

    // Counters
    std::uint64_t total_acquired   = 0;  // successful acquisitions
    std::uint64_t total_released   = 0;  // connections returned
    std::uint64_t total_created    = 0;  // connections created
    std::uint64_t total_destroyed  = 0;  // connections destroyed
    std::uint64_t total_timeouts   = 0;  // acquire timeouts
    std::uint64_t total_errors     = 0;  // errors
    std::uint64_t total_reconnects = 0;  // reconnect attempts

    // Latency statistics (microseconds)
    double avg_acquire_us  = 0;
    double avg_query_us    = 0;
    double peak_acquire_us = 0;

    // Timing
    std::chrono::steady_clock::time_point start_time
        = std::chrono::steady_clock::now();

    [[nodiscard]] std::string format() const {
        using namespace std::chrono;
        auto uptime = duration_cast<seconds>(
            steady_clock::now() - start_time).count();

        return std::format(
            "╔══════════════ Pool Stats ══════════════╗\n"
            "║ Uptime:      {:>8} s                 ║\n"
            "║ Connections: total={:>3} idle={:>3} active={:>3}║\n"
            "║ Pending:     {:>3} requests waiting     ║\n"
            "║ Acquire:     total={:<8} timeout={:<5}║\n"
            "║ Created:     {:<6}  Destroyed: {:<6}  ║\n"
            "║ Reconnects:  {:<6}  Errors:    {:<6}  ║\n"
            "║ Avg acquire: {:>8.1f} μs               ║\n"
            "║ Peak acq:    {:>8.1f} μs               ║\n"
            "╚════════════════════════════════════════╝",
            uptime,
            total_connections, idle_connections, active_connections,
            pending_requests,
            total_acquired, total_timeouts,
            total_created, total_destroyed,
            total_reconnects, total_errors,
            avg_acquire_us, peak_acquire_us
        );
    }
};

// ════════════════════════════════════════════════════════════
//  PoolMonitor: lock-free statistics collector; counters are exact,
//  averages are approximate under concurrent updates
// ════════════════════════════════════════════════════════════
class PoolMonitor {
public:
    // ── Record acquire latency ────────────────────────────
    void record_acquire(std::chrono::microseconds us, bool success) {
        if (success) {
            ++total_acquired_;
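            // Incremental (cumulative) mean: avg += (x - avg) / n. The
            // load/store pair is not one atomic step, so concurrent samples
            // can occasionally be lost.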
            double d = static_cast<double>(us.count());
            double cur = avg_acquire_us_.load(std::memory_order_relaxed);
            double cnt = static_cast<double>(total_acquired_.load());
            avg_acquire_us_.store(cur + (d - cur) / cnt,
                                  std::memory_order_relaxed);

            // Peak: CAS loop until we store d or observe a larger value
            double peak = peak_acquire_us_.load(std::memory_order_relaxed);
            while (d > peak &&
                   !peak_acquire_us_.compare_exchange_weak(
                       peak, d, std::memory_order_relaxed))
            {}
        } else {
            ++total_timeouts_;
        }
    }

    void record_query(std::chrono::microseconds us) {
        ++total_queries_;
        double d   = static_cast<double>(us.count());
        double cur = avg_query_us_.load(std::memory_order_relaxed);
        double cnt = static_cast<double>(total_queries_.load());
        avg_query_us_.store(cur + (d - cur) / cnt,
                            std::memory_order_relaxed);
    }

    void record_created()    { ++total_created_;    }
    void record_destroyed()  { ++total_destroyed_;  }
    void record_released()   { ++total_released_;   }
    void record_error()      { ++total_errors_;     }
    void record_reconnect()  { ++total_reconnects_; }

    // ── Take a snapshot ───────────────────────────────────
    [[nodiscard]] PoolStats snapshot(
        std::size_t total, std::size_t idle,
        std::size_t pending) const noexcept
    {
        PoolStats s;
        s.total_connections  = total;
        s.idle_connections   = idle;
        s.active_connections = total > idle ? total - idle : 0;
        s.pending_requests   = pending;
        s.total_acquired     = total_acquired_.load();
        s.total_released     = total_released_.load();
        s.total_created      = total_created_.load();
        s.total_destroyed    = total_destroyed_.load();
        s.total_timeouts     = total_timeouts_.load();
        s.total_errors       = total_errors_.load();
        s.total_reconnects   = total_reconnects_.load();
        s.avg_acquire_us     = avg_acquire_us_.load();
        s.avg_query_us       = avg_query_us_.load();
        s.peak_acquire_us    = peak_acquire_us_.load();
        s.start_time         = start_time_;
        return s;
    }

    void print_stats(std::size_t total, std::size_t idle,
                     std::size_t pending) const {
        std::println("{}", snapshot(total, idle, pending).format());
    }

private:
    std::atomic<std::uint64_t> total_acquired_  { 0 };
    std::atomic<std::uint64_t> total_released_  { 0 };
    std::atomic<std::uint64_t> total_created_   { 0 };
    std::atomic<std::uint64_t> total_destroyed_ { 0 };
    std::atomic<std::uint64_t> total_timeouts_  { 0 };
    std::atomic<std::uint64_t> total_errors_    { 0 };
    std::atomic<std::uint64_t> total_reconnects_{ 0 };
    std::atomic<std::uint64_t> total_queries_   { 0 };
    std::atomic<double>        avg_acquire_us_  { 0.0 };
    std::atomic<double>        avg_query_us_    { 0.0 };
    std::atomic<double>        peak_acquire_us_ { 0.0 };
    std::chrono::steady_clock::time_point start_time_
        = std::chrono::steady_clock::now();
};

} // namespace db
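`PoolMonitor` keeps a running mean, mean_n = mean_(n-1) + (x_n - mean_(n-1)) / n, and a peak updated with a compare-exchange loop. The formula is exact when applied sequentially, but with the mean and the count in separate atomics two threads can read the same old mean and one update is lost. A sketch that keeps the CAS-based maximum and guards the mean with a small mutex instead (`atomic_max` and `RunningMean` are hypothetical names):

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <mutex>

// Lock-free running maximum: retry the CAS until we install `x` or another
// thread has already stored something at least as large.
inline void atomic_max(std::atomic<double>& target, double x) {
    double cur = target.load(std::memory_order_relaxed);
    while (x > cur &&
           !target.compare_exchange_weak(cur, x, std::memory_order_relaxed)) {
        // on failure `cur` holds the latest value; the loop re-checks x > cur
    }
}

// Incremental mean with count and mean updated together under one mutex.
class RunningMean {
public:
    void add(double x) {
        std::scoped_lock lock(mutex_);
        ++count_;
        mean_ += (x - mean_) / static_cast<double>(count_);
    }
    double value() const {
        std::scoped_lock lock(mutex_);
        return mean_;
    }

private:
    mutable std::mutex mutex_;
    std::uint64_t      count_ = 0;
    double             mean_  = 0.0;
};
```

Whether a mutex is acceptable on this path is a trade-off: it adds a short critical section to every acquire, in exchange for averages that are exact rather than approximate.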

connection_pool.hpp (core)

#pragma once
#include "db_connection.hpp"
#include "pool_monitor.hpp"
#include <deque>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <atomic>
#include <concepts>     // std::invocable
#include <type_traits>  // std::invoke_result_t
#include <stop_token>   // C++20
#include <semaphore>    // C++20
#include <print>        // C++23

namespace db {

// ════════════════════════════════════════════════════════════
//  ConnectionPool: the core connection pool
// ════════════════════════════════════════════════════════════
class ConnectionPool {
public:
    // ── Construction / destruction ────────────────────────
    explicit ConnectionPool(PoolConfig cfg,
                            ConnectionFactory factory)
        : cfg_(std::move(cfg))
        , factory_(std::move(factory))
        , capacity_sem_(cfg_.max_connections)  // caps live connections; declare capacity_sem_ after cfg_
    {
        if (!cfg_.valid())
            throw PoolError(PoolErrc::ConnectFailed, "invalid config");

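        // Pre-create the initial connections. Each one takes a semaphore slot,
        // so capacity_sem_ always equals max_connections minus live connections.
        for (std::size_t i = 0; i < cfg_.initial_connections; ++i) {
            if (!capacity_sem_.try_acquire()) break;
            if (auto conn = create_connection()) {
                std::scoped_lock lock(mutex_);
                idle_queue_.push_back(std::move(conn));
            } else {
                capacity_sem_.release();  // creation failed: give the slot back
            }
        }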

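        // Read the size before the maintenance thread can touch idle_queue_
        const std::size_t initial = idle_queue_.size();

        // Start the background maintenance thread (C++20 jthread + stop_token)
        maintenance_thread_ = std::jthread([this](std::stop_token st) {
            maintenance_loop(st);
        });

        std::println("[Pool] started: dsn={} init={}/max={}",
            cfg_.dsn(), initial, cfg_.max_connections);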
    }

    ~ConnectionPool() {
        close();
    }

    // ── Acquire a connection (blocks for at most acquire_timeout) ──
    [[nodiscard]] ConnectionPtr acquire() {
        if (closed_.load(std::memory_order_acquire))
            throw PoolError(PoolErrc::PoolClosed);

        auto t0 = std::chrono::steady_clock::now();
        ++pending_requests_;

        // RAII: decrement pending_requests_ on every exit path
        struct PendingGuard {
            std::atomic<std::size_t>& ref;
            ~PendingGuard() { --ref; }
        } guard{ pending_requests_ };

        const auto deadline = t0 + cfg_.acquire_timeout;
        std::unique_lock lock(mutex_);
        while (true) {
            if (closed_.load(std::memory_order_relaxed))
                throw PoolError(PoolErrc::PoolClosed);

            // ① Fast path: reuse an idle connection (mutex-protected, not lock-free)
            if (!idle_queue_.empty())
                return pop_idle(lock, t0);

            // ② No idle connection: open a new one if a capacity slot is free.
            //    try_acquire() never blocks; blocking here would wait for a slot
            //    (freed only when a connection is destroyed) and ignore
            //    connections returned by other threads in the meantime.
            if (capacity_sem_.try_acquire()) {
                lock.unlock();                  // don't connect under the lock
                if (auto conn = create_connection()) {
                    record_acquire(t0, true);
                    conn->in_use = true;
                    return conn;
                }
                capacity_sem_.release();        // creation failed: return the slot
                lock.lock();
            }

            // ③ Pool full (or creation failed): wait for a release, bounded by
            //    ONE deadline for the whole call instead of a fresh timeout
            bool ready = cond_.wait_until(lock, deadline, [this] {
                return !idle_queue_.empty() ||
                       closed_.load(std::memory_order_relaxed);
            });
            if (!ready) {
                monitor_.record_acquire(elapsed_us(t0), false);
                throw PoolError(PoolErrc::Timeout,
                    std::format("waited {}ms, pool size={}",
                        cfg_.acquire_timeout.count(),
                        total_count_.load()));
            }
        }
    }

    // ── Return a connection ───────────────────────────────
    void release(ConnectionPtr conn) {
        if (!conn) return;

        conn->in_use   = false;
        conn->last_used= std::chrono::steady_clock::now();

        // Transaction-leak detection
        if (conn->in_transaction()) {
            std::println("[Pool] WARN: connection#{} released "
                         "with active transaction, rolling back",
                conn->pool_id);
            try { conn->rollback(); } catch (...) {}
        }

        // Health check (only inspects the connected flag; no ping)
        if (!conn->is_connected()) {
            handle_bad_connection(conn);
            return;
        }

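        // Back to the idle queue, unless close() has already run
        bool parked = false;
        {
            std::scoped_lock lock(mutex_);
            if (!closed_.load(std::memory_order_relaxed)) {
                idle_queue_.push_back(std::move(conn));
                parked = true;
            }
        }
        if (!parked) {
            conn->disconnect();          // pool closed: drop instead of parking
            monitor_.record_destroyed();
            return;
        }
        monitor_.record_released();
        cond_.notify_one();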
    }

    // ── Close the pool ────────────────────────────────────
    void close() {
        bool expected = false;
        if (!closed_.compare_exchange_strong(expected, true))
            return;  // already closed

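        // Stop the maintenance thread and wake all waiters. Locking the mutex
        // first means no waiter can sit between checking closed_ and blocking,
        // so none of them misses this notification.
        maintenance_thread_.request_stop();
        { std::scoped_lock lock(mutex_); }
        cond_.notify_all();
        if (maintenance_thread_.joinable())
            maintenance_thread_.join();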

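        // Disconnect the idle connections. Connections still borrowed by
        // callers are not in idle_queue_ and are not counted here.
        std::scoped_lock lock(mutex_);
        const std::size_t destroyed = idle_queue_.size();
        for (auto& conn : idle_queue_) {
            conn->disconnect();
            monitor_.record_destroyed();
        }
        idle_queue_.clear();
        std::println("[Pool] closed, destroyed {} idle connections", destroyed);
        total_count_ = 0;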
    }

    // ── Statistics (mutex_ must be mutable for these const members) ──
    [[nodiscard]] PoolStats stats() const {
        std::scoped_lock lock(mutex_);
        return monitor_.snapshot(
            total_count_.load(),
            idle_queue_.size(),
            pending_requests_.load());
    }

    void print_stats() const {
        std::scoped_lock lock(mutex_);
        monitor_.print_stats(
            total_count_.load(),
            idle_queue_.size(),
            pending_requests_.load());
    }

    // ── Query: borrow a connection, execute, return it automatically ──
    QueryResult query(const std::string& sql,
                      const Params& params = {})
    {
        auto conn = acquire();
        struct ReleaseGuard {
            ConnectionPool& pool;
            ConnectionPtr   conn;
            ~ReleaseGuard() { pool.release(std::move(conn)); }
        } guard{ *this, conn };

        auto t0 = std::chrono::steady_clock::now();
        try {
            auto result = conn->execute(sql, params);
            monitor_.record_query(elapsed_us(t0));
            return result;
        } catch (...) {
            monitor_.record_error();
            throw;
        }
    }

    // ── Transaction (RAII wrapper); works for void and non-void fn ──
    template<std::invocable<IConnection&> Fn>
    auto transaction(Fn&& fn) -> std::invoke_result_t<Fn, IConnection&>
    {
        using R = std::invoke_result_t<Fn, IConnection&>;
        auto conn = acquire();
        struct ReleaseGuard {
            ConnectionPool& pool;
            ConnectionPtr   conn;
            ~ReleaseGuard() { pool.release(std::move(conn)); }
        } guard{ *this, conn };

        conn->begin_transaction();
        try {
            if constexpr (std::is_void_v<R>) {
                std::invoke(std::forward<Fn>(fn), *conn);  // `auto r = void` would not compile
                conn->commit();
            } else {
                R result = std::invoke(std::forward<Fn>(fn), *conn);
                conn->commit();
                return result;
            }
        } catch (...) {
            try { conn->rollback(); } catch (...) {}
            monitor_.record_error();
            throw;
        }
    }

    // ── Configuration ─────────────────────────────────────
    const PoolConfig& config() const noexcept { return cfg_; }

private:
    // ── Create a new connection (up to max_reconnect_attempts tries) ──
    [[nodiscard]] ConnectionPtr create_connection() {
        for (int attempt = 0;
             attempt < cfg_.max_reconnect_attempts; ++attempt)
        {
            try {
                auto conn = factory_();
                if (conn) {  // guard against a factory returning nullptr
                    conn->pool_id = ++next_id_;
                    if (conn->connect(cfg_)) {
                        ++total_count_;
                        monitor_.record_created();
                        return conn;
                    }
                }
            } catch (const std::exception& e) {
                std::println("[Pool] connect attempt {}/{} failed: {}",
                    attempt + 1, cfg_.max_reconnect_attempts, e.what());
            }

            if (attempt + 1 < cfg_.max_reconnect_attempts) {
                std::this_thread::sleep_for(cfg_.reconnect_delay);
            }
        }

        monitor_.record_error();
        return nullptr;
    }

    // ── Pop a connection from the idle queue ──────────────
    //    Called with `lock` held; unlocks it before returning.
    [[nodiscard]] ConnectionPtr pop_idle(
        std::unique_lock<std::mutex>& lock,
        std::chrono::steady_clock::time_point t0)
    {
        auto conn = std::move(idle_queue_.front());
        idle_queue_.pop_front();
        lock.unlock();

        conn->in_use = true;
        record_acquire(t0, true);
        return conn;
    }

    // ── Handle a broken connection ────────────────────────
    void handle_bad_connection(ConnectionPtr& conn) {
        std::println("[Pool] Conn#{} is bad, reconnecting...",
            conn->pool_id);
        monitor_.record_reconnect();

        conn->disconnect();

        // Try to reconnect
        if (conn->connect(cfg_)) {
            std::scoped_lock lock(mutex_);
            idle_queue_.push_back(std::move(conn));
            cond_.notify_one();
        } else {
            // Reconnect failed: destroy the connection and free its slot
            --total_count_;
            monitor_.record_destroyed();
            capacity_sem_.release();

            // Try to top the pool back up to min_connections
            ensure_min_connections();
        }
    }

    // ── Keep at least min_connections open ────────────────
    void ensure_min_connections() {
        // total_count_ is atomic, so reading it needs no mutex
        std::size_t current = total_count_.load();

        while (current < cfg_.min_connections) {
            if (!capacity_sem_.try_acquire()) break;

            auto conn = create_connection();
            if (conn) {
                std::scoped_lock lock(mutex_);
                idle_queue_.push_back(std::move(conn));
                cond_.notify_one();
                ++current;
            } else {
                capacity_sem_.release();
                break;
            }
        }
    }

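    // ── Background maintenance thread ─────────────────────
    void maintenance_loop(std::stop_token st) {
        using namespace std::chrono;

        // sleep_for() ignores stop requests, so the jthread destructor could
        // block for a full interval; condition_variable_any::wait_for with a
        // stop_token returns as soon as a stop is requested.
        std::mutex                  wait_mutex;
        std::condition_variable_any wait_cv;

        while (!st.stop_requested()) {
            // Wake up every min(keepalive_interval, idle_timeout / 4)
            {
                std::unique_lock wait_lock(wait_mutex);
                wait_cv.wait_for(wait_lock, st,
                    std::min(cfg_.keepalive_interval,
                             duration_cast<seconds>(cfg_.idle_timeout / 4)),
                    [] { return false; });
            }

            if (st.stop_requested()) break;

            // ① Idle-timeout check + heartbeat
            heartbeat_check();

            // ② Keep at least min_connections
            ensure_min_connections();

            // ③ Periodically print statistics (debugging)
            // print_stats();
        }
    }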

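    void heartbeat_check() {
        using namespace std::chrono;
        auto now = steady_clock::now();

        // Take the idle connections out under the lock and ping them without
        // it, so acquire()/release() are not blocked by network round trips.
        // (acquire() may briefly see an empty queue and open a new connection,
        // still bounded by capacity_sem_.)
        std::deque<ConnectionPtr> candidates;
        {
            std::scoped_lock lock(mutex_);
            candidates.swap(idle_queue_);
        }

        std::vector<ConnectionPtr> survivors;
        std::vector<ConnectionPtr> bad_conns;
        std::vector<ConnectionPtr> expired_conns;

        // Count down locally: total_count_ is only decremented after this
        // loop, so re-reading it here could close every expired connection.
        std::size_t remaining = total_count_.load();

        for (auto& conn : candidates) {
            auto idle_dur = duration_cast<seconds>(now - conn->last_used);

            // Idle timeout → close, but never drop below min_connections
            if (idle_dur > cfg_.idle_timeout &&
                remaining > cfg_.min_connections)
            {
                --remaining;
                expired_conns.push_back(std::move(conn));
                continue;
            }

            // Heartbeat
            if (cfg_.enable_health_check && !conn->ping()) {
                bad_conns.push_back(std::move(conn));
                continue;
            }

            // last_used is left alone (assumed to be stamped by release()):
            // refreshing it on every heartbeat would stop the idle timeout
            // from ever firing.
            survivors.push_back(std::move(conn));
        }

        // Put healthy connections back; release() may have added others meanwhile
        {
            std::scoped_lock lock(mutex_);
            for (auto& conn : survivors)
                idle_queue_.push_back(std::move(conn));
        }
        cond_.notify_all();

        // Close expired connections (idle timeout)
        for (auto& conn : expired_conns) {
            std::println("[Pool] Conn#{} idle timeout, closing",
                conn->pool_id);
            conn->disconnect();
            --total_count_;
            monitor_.record_destroyed();
            capacity_sem_.release();
        }

        // Broken connections: try to reconnect
        for (auto& conn : bad_conns) {
            handle_bad_connection(conn);
        }
    }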

    // ── Helpers ───────────────────────────────────────────
    void record_acquire(
        std::chrono::steady_clock::time_point t0,
        bool success)
    {
        monitor_.record_acquire(elapsed_us(t0), success);
    }

    static std::chrono::microseconds elapsed_us(
        std::chrono::steady_clock::time_point t0) noexcept
    {
        return std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - t0);
    }

    // ── Member variables ──────────────────────────────────
    PoolConfig         cfg_;
    ConnectionFactory  factory_;

    mutable std::mutex      mutex_;
    std::condition_variable cond_;
    std::deque<ConnectionPtr> idle_queue_;   // idle connections

    // C++20 counting semaphore: caps the total number of connections
    std::counting_semaphore<> capacity_sem_;

    std::atomic<std::size_t> total_count_     { 0 };
    std::atomic<std::size_t> pending_requests_{ 0 };
    std::atomic<std::size_t> next_id_         { 0 };
    std::atomic<bool>        closed_          { false };

    PoolMonitor monitor_;

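    // C++20 jthread: request_stop() + join() on destruction. Being declared
    // last, it is destroyed first, while the members it uses are still alive.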
    std::jthread maintenance_thread_;
};

} // namespace db

pool_guard.hpp

#pragma once
#include "connection_pool.hpp"

namespace db {

// ════════════════════════════════════════════════════════════
//  ConnectionGuard: RAII connection guard
//  Returns the connection to the pool automatically when it leaves scope
// ════════════════════════════════════════════════════════════
class ConnectionGuard {
public:
    explicit ConnectionGuard(ConnectionPool& pool)
        : pool_(&pool)
        , conn_(pool.acquire())
    {}

    ~ConnectionGuard() {
        if (pool_ && conn_)
            pool_->release(std::move(conn_));
    }

    // Non-copyable; move-constructible (no move assignment)
    ConnectionGuard(const ConnectionGuard&) = delete;
    ConnectionGuard& operator=(const ConnectionGuard&) = delete;

    ConnectionGuard(ConnectionGuard&& o) noexcept
        : pool_(o.pool_), conn_(std::move(o.conn_)) {
        o.pool_ = nullptr;
    }

    // ── Transparent access ────────────────────────────────
    IConnection* operator->() const noexcept { return conn_.get(); }
    IConnection& operator*()  const noexcept { return *conn_;      }

    // ── SQL shortcut ──────────────────────────────────────
    QueryResult execute(const std::string& sql,
                        const Params& params = {}) {
        return conn_->execute(sql, params);
    }

    // ── Transaction support ───────────────────────────────
    void begin()    { conn_->begin_transaction(); }
    void commit()   { conn_->commit();   }
    void rollback() { conn_->rollback(); }

    // ── Return the connection early (manual release) ──────
    void release() {
        if (pool_ && conn_) {
            pool_->release(std::move(conn_));
            pool_ = nullptr;
        }
    }

private:
    ConnectionPool* pool_;
    ConnectionPtr   conn_;
};

// ════════════════════════════════════════════════════════════
//  TransactionGuard: RAII transaction guard
//  Rolls back automatically on scope exit unless commit() was called
// ════════════════════════════════════════════════════════════
class TransactionGuard {
public:
    explicit TransactionGuard(ConnectionGuard& guard)
        : guard_(guard)
    {
        guard_.begin();
    }

    ~TransactionGuard() {
        if (!committed_ && !rolled_back_) {
            try { guard_.rollback(); } catch (...) {}
        }
    }

    void commit() {
        guard_.commit();
        committed_ = true;
    }

    void rollback() {
        guard_.rollback();
        rolled_back_ = true;
    }

    TransactionGuard(const TransactionGuard&) = delete;
    TransactionGuard& operator=(const TransactionGuard&) = delete;

private:
    ConnectionGuard& guard_;
    bool committed_   = false;
    bool rolled_back_ = false;
};

} // namespace db
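The two guards depend on C++ destruction order: locals are destroyed in reverse order of declaration, so a `TransactionGuard` declared after its `ConnectionGuard` issues ROLLBACK before the connection goes back to the pool. A minimal sketch with hypothetical `MiniPool`, `MiniConn`, `MiniConnGuard` and `MiniTxGuard` types (not the classes above) records each action in a log so the order can be checked:

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Shared event log so the order of RAII actions can be inspected.
inline std::vector<std::string> g_events;

// Hypothetical stand-ins for IConnection / ConnectionPool.
struct MiniConn {
    void begin()    { g_events.push_back("BEGIN"); }
    void commit()   { g_events.push_back("COMMIT"); }
    void rollback() { g_events.push_back("ROLLBACK"); }
};

struct MiniPool {
    std::unique_ptr<MiniConn> acquire() {
        g_events.push_back("ACQUIRE");
        return std::make_unique<MiniConn>();
    }
    void release(std::unique_ptr<MiniConn>) { g_events.push_back("RELEASE"); }
};

// Returns the connection to the pool when it goes out of scope.
class MiniConnGuard {
public:
    explicit MiniConnGuard(MiniPool& p) : pool_(p), conn_(p.acquire()) {}
    ~MiniConnGuard() { if (conn_) pool_.release(std::move(conn_)); }
    MiniConnGuard(const MiniConnGuard&) = delete;
    MiniConnGuard& operator=(const MiniConnGuard&) = delete;
    MiniConn* operator->() const noexcept { return conn_.get(); }
private:
    MiniPool&                 pool_;
    std::unique_ptr<MiniConn> conn_;
};

// Rolls back in the destructor unless commit() succeeded.
class MiniTxGuard {
public:
    explicit MiniTxGuard(MiniConnGuard& g) : g_(g) { g_->begin(); }
    ~MiniTxGuard() { if (!done_) { try { g_->rollback(); } catch (...) {} } }
    void commit() { g_->commit(); done_ = true; }
private:
    MiniConnGuard& g_;
    bool           done_ = false;
};

// A unit of work that may fail before COMMIT.
inline void do_work(MiniPool& pool, bool fail) {
    MiniConnGuard cg(pool);   // declared first → destroyed last
    MiniTxGuard   tg(cg);     // declared second → destroyed first
    if (fail) throw std::runtime_error("business error");
    tg.commit();
}
```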

db_pool.hpp (unified entry point)

#pragma once
#include "connection_pool.hpp"
#include "pool_guard.hpp"
#include "mysql_connection.hpp"

namespace db {

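// ── Global singleton connection pool ──────────────────────
//    init()/shutdown() are not synchronized: call them from a single
//    thread, before and after all other use.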
class DB {
public:
    static void init(PoolConfig cfg) {
        instance() = std::make_unique<ConnectionPool>(
            std::move(cfg), make_mysql_factory());
    }

    static ConnectionPool& pool() {
        auto& inst = instance();
        if (!inst) throw PoolError(PoolErrc::PoolClosed,
            "DB::init() not called");
        return *inst;
    }

    static ConnectionGuard conn() {
        return ConnectionGuard(pool());
    }

    static QueryResult query(const std::string& sql,
                             const Params& params = {}) {
        return pool().query(sql, params);
    }

    template<std::invocable<IConnection&> Fn>
    static auto transaction(Fn&& fn) {
        return pool().transaction(std::forward<Fn>(fn));
    }

    static void shutdown() { instance().reset(); }

private:
    static std::unique_ptr<ConnectionPool>& instance() {
        static std::unique_ptr<ConnectionPool> inst;
        return inst;
    }
};

} // namespace db

III. Complete Usage Example

#include "db_pool.hpp"
#include <atomic>
#include <chrono>
#include <cstdint>
#include <print>      // std::println requires C++23
#include <thread>
#include <vector>

using namespace db;
using namespace std::chrono_literals;

// ── Basic usage ───────────────────────────────────────────
void basic_example() {
    std::println("\n=== Basic usage ===");

    // 1. ConnectionGuard (recommended)
    {
        ConnectionGuard g(DB::pool());
        auto result = g.execute("SELECT id, name FROM users WHERE id = ?",
                                { std::int64_t(1) });

        for (const auto& row : result) {
            std::println("id={} name={}",
                row.get_int("id"),
                row.get_str("name"));
        }
    } // ← connection is returned to the pool automatically

    // 2. Direct query (more concise)
    auto result = DB::query("SELECT COUNT(*) as cnt FROM orders");
    std::println("orders count = {}", result[0].get_int("cnt"));
}

// ── Transactions ──────────────────────────────────────────
void transaction_example() {
    std::println("\n=== Transactions ===");

    // Option 1: lambda transaction (COMMIT on return, ROLLBACK on exception)
    DB::transaction([](IConnection& conn) -> void {
        conn.execute("INSERT INTO accounts(name, balance) VALUES(?,?)",
                     { std::string("Alice"), std::int64_t(1000) });

        conn.execute("UPDATE accounts SET balance = balance - ? WHERE id=?",
                     { std::int64_t(100), std::int64_t(1) });

        conn.execute("UPDATE accounts SET balance = balance + ? WHERE id=?",
                     { std::int64_t(100), std::int64_t(2) });

        // Throwing an exception here → automatic ROLLBACK
        // throw std::runtime_error("test rollback");
    });

    // Option 2: TransactionGuard (finer-grained control)
    {
        ConnectionGuard  cg(DB::pool());
        TransactionGuard tg(cg);     // BEGIN

        cg.execute("DELETE FROM temp_data WHERE expired = 1");
        cg.execute("INSERT INTO logs(msg) VALUES(?)",
                   { std::string("cleaned") });
        tg.commit();                 // COMMIT

        // No try/catch is needed: if either execute() throws, tg's
        // destructor issues ROLLBACK, then cg returns the connection.
    }
}

// ── Multithreaded concurrency test ────────────────────────
void concurrent_test() {
    std::println("\n=== Concurrency test ===");

    constexpr int THREADS = 20;
    constexpr int OPS     = 50;

    std::vector<std::thread> workers;
    std::atomic<int> success{ 0 }, failed{ 0 };

    for (int t = 0; t < THREADS; ++t) {
        workers.emplace_back([&, t] {
            for (int i = 0; i < OPS; ++i) {
                try {
                    auto res = DB::query(
                        "SELECT ? + ? AS result",
                        { std::int64_t(t), std::int64_t(i) });
                    ++success;
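                } catch (const PoolError& e) {
                    ++failed;
                    std::println("[T{}] pool error: {}", t, e.what());
                } catch (const std::exception& e) {
                    // Driver/SQL errors too: an exception escaping a
                    // std::thread's function calls std::terminate
                    ++failed;
                    std::println("[T{}] error: {}", t, e.what());
                }

                // Simulate business processing on every 5th iteration
                if (i % 5 == 0)
                    std::this_thread::sleep_for(1ms);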
            }
        });
    }

    for (auto& w : workers) w.join();

    std::println("success={} failed={}", success.load(), failed.load());
    DB::pool().print_stats();
}

// ── Pool stress test (sequential, single thread) ──────────
void stress_test() {
    std::println("\n=== Stress test ===");

    constexpr int N = 1000;
    // steady_clock is monotonic, which is what interval timing needs
    auto t0 = std::chrono::steady_clock::now();

    for (int i = 0; i < N; ++i) {
        DB::query("SELECT 1");
    }

    auto t1 = std::chrono::steady_clock::now();
    auto us = std::chrono::duration_cast<
        std::chrono::microseconds>(t1 - t0).count();

    // The total includes each server round trip, not just pool overhead
    std::println("{} queries took {} μs, average {:.1f} μs/query",
        N, us, double(us) / N);
}

int main() {
    // ── Initialize the connection pool ────────────────────
    auto cfg = PoolConfigBuilder{}
        .host("127.0.0.1")
        .port(3306)
        .user("root")
        .password("123456")
        .database("testdb")
        .min_conn(2)
        .max_conn(10)
        .acquire_timeout(3s)
        .idle_timeout(300s)
        .keepalive(60s)
        .build();

    DB::init(cfg);

    // ── Run the examples ──────────────────────────────────
    basic_example();
    transaction_example();
    concurrent_test();
    stress_test();

    DB::pool().print_stats();
    DB::shutdown();
    return 0;
}

IV. Summary of Principles

┌──────────────────────────────────────────────────────────────┐
│                      acquire() flow                          │
│                                                              │
│  ┌────────────┐  idle queue non-empty?                       │
│  │ request in │──── YES ──► take it, O(1)  ◄── fast path     │
│  └────────────┘                                              │
│       │ NO                                                   │
│       ▼                                                      │
│  total connections < max?                                    │
│  ──── YES ──► acquire semaphore slot ──► create ──► return   │
│       │ NO                                                   │
│       ▼                                                      │
│  wait on condition variable (with timeout)                   │
│       ──► woken by release() ──► take a connection           │
│       │ timeout                                              │
│       ▼                                                      │
│  throw PoolError (timeout)                                   │
└──────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────┐
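│                Background maintenance thread                 │
│                                                              │
│  Every min(keepalive_interval, idle_timeout / 4):            │
│  ① idle longer than idle_timeout and total > min → close     │
│  ② ping idle connections → reconnect on failure              │
│     → destroy if the reconnect also fails                    │
│  ③ total connections < min_connections → open new ones       │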
└──────────────────────────────────────────────────────────────┘
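The idle-timeout step has to stop closing connections once the pool would reach `min_connections`, so the count must go down as each connection is selected instead of being re-read from a total that is decremented only later. A minimal sketch of that selection as a pure function; `IdleConn` and `select_expired` are hypothetical names, and `last_used` is assumed to be the time of the last business use rather than the last ping (if every heartbeat refreshed it, no connection would ever look idle).

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <vector>

using Clock = std::chrono::steady_clock;

// Hypothetical idle-connection record: only the fields this check needs.
struct IdleConn {
    int               id;
    Clock::time_point last_used;   // last business use, NOT last ping
};

// Returns the ids of idle connections to close: idle longer than
// idle_timeout, but never so many that the pool drops below min_connections.
std::vector<int> select_expired(const std::vector<IdleConn>& idle,
                                std::size_t total_count,
                                std::size_t min_connections,
                                std::chrono::seconds idle_timeout,
                                Clock::time_point now)
{
    std::vector<int> expired;
    std::size_t remaining = total_count;           // counts down as we select
    for (const auto& c : idle) {
        if (remaining <= min_connections) break;   // keep the floor
        if (now - c.last_used > idle_timeout) {
            expired.push_back(c.id);
            --remaining;
        }
    }
    return expired;
}
```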

┌──────────────────────────────────────────────────────────────┐
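│                    C++20/23 features used                    │
├─────────────────────┬────────────────────────────────────────┤
│ std::jthread        │ background thread that stops and joins │
│                     │ itself on destruction                  │
│ std::stop_token     │ lets the background loop stop cleanly  │
│ std::counting_      │ caps the total connection count        │
│   semaphore         │ (replaces a hand-written counter)      │
│ std::format         │ formatted log output (C++20)           │
│ std::println        │ console output; C++23, needs <print>   │
│ std::atomic<double> │ averages updated without a mutex       │
│ concepts/ranges     │ type constraints and range operations  │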
└─────────────────────┴────────────────────────────────────────┘
