C++20数据库连接池完整实现
一、原理
1. 为什么需要连接池?
不用连接池:
┌────────────────────────────────────────────────────────┐
│ 请求到来 → TCP握手 → 数据库认证 → 执行SQL → 断开连接 │
│ ↑_________________________________↑ │
│ 每次都要 100~500ms 的连接建立开销 │
└────────────────────────────────────────────────────────┘
用连接池:
┌────────────────────────────────────────────────────────┐
│ 请求到来 → 从池中取连接(~1μs) → 执行SQL → 归还连接 │
│ │
│ 池中始终维护 N 个已建立的连接,随取随用 │
└────────────────────────────────────────────────────────┘
2. 连接池核心状态机
┌─────────────┐
│ 创建连接 │
└──────┬──────┘
│
┌──────▼──────┐
┌────►│ 空闲 │◄────────────────┐
│ └──────┬──────┘ │
│ │ acquire() │
│ ┌──────▼──────┐ │
│ │ 使用中 │ │
│ └──────┬──────┘ │
│ │ release() │
│ ┌──────▼──────┐ │
│ │ 健康检查 │─── 通过 ──────────┘
│ └──────┬──────┘
│ │ 失败
│ ┌──────▼──────┐
└─────│ 重新连接 │
└──────┬──────┘
│ 失败次数过多
┌──────▼──────┐
│ 销毁连接 │
└─────────────┘
3. 整体架构
┌─────────────────────────────────────────────────────────────┐
│ ConnectionPool │
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Connection │ │ Connection │ │ Connection │ │
│ │ [空闲] │ │ [使用中] │ │ [空闲] │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ↑ ↑ │
│ ┌──────┴────────────────────────────────────┴────────┐ │
│ │ idle_queue_ │ │
│ │ (条件变量 + 双端队列) │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 后台维护线程 │ │
│ │ ① 心跳检测(keepalive) ② 空闲超时回收 │ │
│ │ ③ 最小连接数保持 ④ 统计上报 │ │
│ └─────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
二、完整实现
文件结构
db_pool/
├── db_error.hpp # 错误处理
├── db_config.hpp # 配置
├── db_connection.hpp # 连接抽象层
├── mysql_connection.hpp # MySQL具体实现
├── connection_pool.hpp # 连接池核心
├── pool_guard.hpp # RAII连接守卫
├── pool_monitor.hpp # 监控/统计
└── db_pool.hpp # 统一入口
db_error.hpp
#pragma once
#include <stdexcept>
#include <string>
#include <system_error>
#include <format>
namespace db {
// ── 错误码 ────────────────────────────────────────────────
// Error codes reported by the connection pool. The numeric values are the
// ones carried inside std::error_code (see make_error_code).
enum class PoolErrc : int {
Success = 0,
Timeout = 1, // timed out waiting for a connection
PoolExhausted = 2, // the pool is full
ConnectFailed = 3, // could not connect to the database
QueryFailed = 4, // SQL execution failed
InvalidConn = 5, // the connection is no longer valid
PoolClosed = 6, // the pool has been closed
MaxRetryExceeded = 7, // retry limit exceeded
};
// 注册为系统错误类别
struct PoolErrCategory : std::error_category {
const char* name() const noexcept override { return "db_pool"; }
std::string message(int ev) const override {
switch (static_cast<PoolErrc>(ev)) {
case PoolErrc::Success: return "success";
case PoolErrc::Timeout: return "wait for connection timeout";
case PoolErrc::PoolExhausted: return "connection pool exhausted";
case PoolErrc::ConnectFailed: return "database connection failed";
case PoolErrc::QueryFailed: return "query execution failed";
case PoolErrc::InvalidConn: return "invalid connection";
case PoolErrc::PoolClosed: return "pool is closed";
case PoolErrc::MaxRetryExceeded: return "max retry exceeded";
default: return "unknown error";
}
}
};
/// Returns the single process-wide PoolErrCategory instance
/// (function-local static, constructed on first use).
inline const PoolErrCategory& pool_category() {
    static PoolErrCategory instance{};
    return instance;
}
/// Wraps a PoolErrc in a std::error_code bound to pool_category().
inline std::error_code make_error_code(PoolErrc e) {
    const int raw = static_cast<int>(e);
    return std::error_code(raw, pool_category());
}
// ── 异常类 ────────────────────────────────────────────────
/// Exception thrown by the pool; carries the PoolErrc alongside the text.
class PoolError : public std::runtime_error {
public:
    /// @param code  error code to carry
    /// @param msg   optional detail; when non-empty, what() is
    ///              "[<category message>] <msg>", otherwise just the
    ///              category message for `code`.
    explicit PoolError(PoolErrc code, std::string msg = "")
        : std::runtime_error(compose(code, msg))
        , code_(code) {}

    /// The error code passed at construction.
    PoolErrc code() const noexcept { return code_; }

private:
    // Builds the what() string from the category text and optional detail.
    static std::string compose(PoolErrc code, const std::string& detail) {
        const int raw = static_cast<int>(code);
        if (detail.empty()) {
            return pool_category().message(raw);
        }
        return std::format("[{}] {}", pool_category().message(raw), detail);
    }

    PoolErrc code_;
};
// ── 结果类型(不抛异常的API用)────────────────────────────
/// Value-or-error holder for APIs that do not throw.
/// NOTE(review): relies on std::optional; confirm <optional> is included by
/// this header rather than arriving transitively.
template<typename T>
class Result {
public:
    /// Successful result holding `val`.
    static Result ok(T val) {
        Result out;
        out.val_.emplace(std::move(val));
        out.ok_ = true;
        return out;
    }

    /// Failed result carrying error code `e` and an optional detail message.
    static Result err(PoolErrc e, std::string msg = "") {
        Result out;
        out.ec_  = make_error_code(e);
        out.msg_ = std::move(msg);
        return out;
    }

    /// True when the result holds a value.
    explicit operator bool() const noexcept { return ok_; }

    /// Access to the value; throws PoolError (built from the stored code and
    /// message) when the result is an error.
    const T& value() const {
        check();
        return *val_;
    }
    T& value() {
        check();
        return *val_;
    }

    std::error_code error() const noexcept { return ec_; }
    const std::string& message() const noexcept { return msg_; }

private:
    // Throws when no value is held.
    void check() const {
        if (ok_) {
            return;
        }
        throw PoolError(static_cast<PoolErrc>(ec_.value()), msg_);
    }

    std::optional<T> val_;
    std::error_code  ec_;
    std::string      msg_;
    bool             ok_ = false;
};
/// Specialization for operations that produce no value: success flag plus
/// error code and message only.
template<>
class Result<void> {
public:
    /// Successful result.
    static Result ok() {
        Result out;
        out.ok_ = true;
        return out;
    }

    /// Failed result carrying error code `e` and an optional detail message.
    static Result err(PoolErrc e, std::string msg = "") {
        Result out;
        out.ec_  = make_error_code(e);
        out.msg_ = std::move(msg);
        return out;
    }

    /// True when the operation succeeded.
    explicit operator bool() const noexcept { return ok_; }

    std::error_code error() const noexcept { return ec_; }
    const std::string& message() const noexcept { return msg_; }

private:
    std::error_code ec_;
    std::string     msg_;
    bool            ok_ = false;
};
} // namespace db
// Register db::PoolErrc as an error-code enum so it converts implicitly to
// std::error_code.
namespace std {
template<>
struct is_error_code_enum<db::PoolErrc> : true_type {};
}  // namespace std
db_config.hpp
#pragma once
#include <string>
#include <chrono>
#include <cstddef>
namespace db {
using namespace std::chrono_literals;
// ════════════════════════════════════════════════════════════
// PoolConfig:连接池配置
// ════════════════════════════════════════════════════════════
/// Connection-pool configuration: database endpoint, pool sizing, timeouts,
/// retry policy and health-check settings.
struct PoolConfig {
    // ── Database connection parameters ──
    std::string   host     = "127.0.0.1";
    std::uint16_t port     = 3306;
    std::string   user     = "root";
    std::string   password = "";
    std::string   database = "";
    std::string   charset  = "utf8mb4";

    // ── Pool sizing ──
    std::size_t min_connections     = 2;   // minimum number of connections to keep
    std::size_t max_connections     = 16;  // upper bound on connections
    std::size_t initial_connections = 4;   // connections pre-built at start-up

    // ── Timeouts ──
    // Durations are spelled out so the defaults do not depend on a
    // `using namespace std::chrono_literals` being in scope.
    std::chrono::milliseconds connect_timeout    = std::chrono::seconds{5};    // establishing a connection
    std::chrono::milliseconds acquire_timeout    = std::chrono::seconds{3};    // waiting to obtain a connection
    std::chrono::milliseconds query_timeout      = std::chrono::seconds{30};   // running a query
    std::chrono::seconds      idle_timeout       = std::chrono::seconds{600};  // lifetime of an idle connection
    std::chrono::seconds      keepalive_interval = std::chrono::seconds{60};   // heartbeat period

    // ── Retry policy ──
    int                       max_reconnect_attempts = 3;                            // connect attempts per connection
    std::chrono::milliseconds reconnect_delay        = std::chrono::milliseconds{500};  // pause between attempts

    // ── Health check ──
    bool        enable_health_check = true;
    std::string ping_query          = "SELECT 1";  // heartbeat SQL

    /// Returns true when the configuration is usable.
    ///
    /// Besides the endpoint and sizing checks this rejects values that make
    /// the pool malfunction: a non-positive `max_reconnect_attempts` (the
    /// connect loop would never run, so no connection could be created),
    /// non-positive `keepalive_interval` / `idle_timeout` (the maintenance
    /// interval derived from them would be zero) and negative delays.
    [[nodiscard]] bool valid() const noexcept {
        using std::chrono::milliseconds;
        using std::chrono::seconds;

        const bool endpoint_ok = !host.empty() && port > 0 && !user.empty();
        const bool sizing_ok   = max_connections > 0
                              && min_connections <= max_connections
                              && initial_connections <= max_connections;
        const bool retry_ok    = max_reconnect_attempts > 0
                              && reconnect_delay >= milliseconds::zero();
        const bool timing_ok   = acquire_timeout >= milliseconds::zero()
                              && keepalive_interval > seconds::zero()
                              && idle_timeout > seconds::zero();
        return endpoint_ok && sizing_ok && retry_ok && timing_ok;
    }

    /// Builds a DSN string for logging: "mysql://<user>@<host>:<port>/<database>".
    /// The password is not part of the output. Plain concatenation is used so
    /// the header needs nothing beyond <string>.
    [[nodiscard]] std::string dsn() const {
        std::string out = "mysql://";
        out += user;
        out += '@';
        out += host;
        out += ':';
        out += std::to_string(port);
        out += '/';
        out += database;
        return out;
    }
};
// ── Builder模式构造配置 ───────────────────────────────────
class PoolConfigBuilder {
public:
PoolConfigBuilder& host(std::string h)
{ cfg_.host = std::move(h); return *this; }
PoolConfigBuilder& port(std::uint16_t p)
{ cfg_.port = p; return *this; }
PoolConfigBuilder& user(std::string u)
{ cfg_.user = std::move(u); return *this; }
PoolConfigBuilder& password(std::string pw)
{ cfg_.password = std::move(pw); return *this; }
PoolConfigBuilder& database(std::string db)
{ cfg_.database = std::move(db); return *this; }
PoolConfigBuilder& min_conn(std::size_t n)
{ cfg_.min_connections = n; return *this; }
PoolConfigBuilder& max_conn(std::size_t n)
{ cfg_.max_connections = n; return *this; }
PoolConfigBuilder& acquire_timeout(std::chrono::milliseconds t)
{ cfg_.acquire_timeout = t; return *this; }
PoolConfigBuilder& idle_timeout(std::chrono::seconds t)
{ cfg_.idle_timeout = t; return *this; }
PoolConfigBuilder& keepalive(std::chrono::seconds t)
{ cfg_.keepalive_interval = t; return *this; }
[[nodiscard]] PoolConfig build() const { return cfg_; }
private:
PoolConfig cfg_;
};
} // namespace db
db_connection.hpp
#pragma once
#include "db_error.hpp"
#include "db_config.hpp"
#include <string>
#include <vector>
#include <unordered_map>
#include <any>
#include <chrono>
#include <memory>
#include <functional>
#include <variant>
namespace db {
// ── 查询结果行 ────────────────────────────────────────────
// One column value of a result row: monostate stands for SQL NULL, the other
// alternatives hold an integer, floating-point or string value.
using FieldValue = std::variant<
std::monostate, // NULL
std::int64_t,
double,
std::string
>;
/// One result row: column names, their values, and a name → position index
/// into `values`.
struct Row {
    std::vector<std::string>                     columns;
    std::vector<FieldValue>                      values;
    std::unordered_map<std::string, std::size_t> index;

    /// Value of column `col`.
    /// @throws PoolError(QueryFailed) when the column is not in `index`.
    const FieldValue& operator[](const std::string& col) const {
        if (const auto pos = index.find(col); pos != index.end()) {
            return values[pos->second];
        }
        throw PoolError(PoolErrc::QueryFailed,
                        std::format("column '{}' not found", col));
    }

    /// Typed access: returns the value of `col` as `T` via std::get, so a
    /// column holding a different alternative (including NULL) makes std::get
    /// throw std::bad_variant_access.
    template<typename T>
    T get(const std::string& col) const {
        const FieldValue& field = (*this)[col];
        return std::get<T>(field);
    }

    std::string  get_str(const std::string& col) const    { return get<std::string>(col); }
    std::int64_t get_int(const std::string& col) const    { return get<std::int64_t>(col); }
    double       get_double(const std::string& col) const { return get<double>(col); }

    /// True when column `col` holds std::monostate (SQL NULL).
    bool is_null(const std::string& col) const {
        const FieldValue& field = (*this)[col];
        return std::holds_alternative<std::monostate>(field);
    }
};
struct QueryResult {
std::vector<Row> rows;
std::uint64_t affected_rows = 0;
std::uint64_t last_insert_id = 0;
std::size_t field_count = 0;
bool empty() const noexcept { return rows.empty(); }
std::size_t size() const noexcept { return rows.size(); }
const Row& operator[](std::size_t i) const { return rows[i]; }
// 范围遍历
auto begin() const { return rows.begin(); }
auto end() const { return rows.end(); }
};
// ── 预处理语句参数 ────────────────────────────────────────
// A single bound parameter for a prepared statement; std::monostate is the
// empty alternative (presumably bound as SQL NULL — confirm in the concrete
// connection implementation).
using Param = std::variant<
std::monostate,
std::int64_t,
double,
std::string,
bool
>;
// Ordered parameter list passed to IConnection::execute(sql, params).
using Params = std::vector<Param>;
// ════════════════════════════════════════════════════════════
// IConnection:数据库连接抽象接口
// ════════════════════════════════════════════════════════════
// Abstract database connection. Concrete drivers implement the pure virtual
// operations; the public data members at the bottom are bookkeeping fields
// that the connection pool reads and writes directly.
class IConnection {
public:
virtual ~IConnection() = default;
// ── Basic operations ──────────────────────────────────
// connect() returns whether the connection was established.
virtual bool connect(const PoolConfig& cfg) = 0;
virtual void disconnect() = 0;
virtual bool is_connected() const = 0;
virtual bool ping() = 0; // heartbeat check
// ── SQL execution ─────────────────────────────────────
// The second overload takes bound parameters for the statement.
virtual QueryResult execute(const std::string& sql) = 0;
virtual QueryResult execute(const std::string& sql,
const Params& params) = 0;
// ── Transactions ──────────────────────────────────────
virtual void begin_transaction() = 0;
virtual void commit() = 0;
virtual void rollback() = 0;
virtual bool in_transaction() const= 0;
// ── Escaping (injection protection) ───────────────────
virtual std::string escape(const std::string& s) const = 0;
// ── Metadata ──────────────────────────────────────────
virtual std::uint64_t last_insert_id() const = 0;
virtual std::uint64_t affected_rows() const = 0;
virtual std::string server_version() const = 0;
virtual std::string error_message() const = 0;
virtual int error_code() const = 0;
// ── Used internally by the connection pool ────────────
// NOTE(review): these are plain (non-atomic) fields; any synchronisation is
// the pool's responsibility.
std::size_t pool_id = 0; // ID within the pool
bool in_use = false; // whether the connection is currently lent out
// Both timestamps default to the moment the object is constructed.
std::chrono::steady_clock::time_point
last_used = std::chrono::steady_clock::now();
std::chrono::steady_clock::time_point
created_at= std::chrono::steady_clock::now();
int fail_count = 0; // consecutive failure count
};
// Shared-ownership handle to a connection, and the callable the pool invokes
// to obtain a new (not yet connected) connection object.
using ConnectionPtr = std::shared_ptr<IConnection>;
using ConnectionFactory = std::function<ConnectionPtr()>;
} // namespace db
mysql_connection.hpp
#pragma once
#include "db_connection.hpp"
#include <atomic>
#include <print>
// 这里用模拟实现,实际项目接入 mysqlclient / mariadb-connector
// #include <mysql/mysql.h>
namespace db {
// ════════════════════════════════════════════════════════════
// MockMySQLConnection:模拟MySQL连接(可替换为真实实现)
// 真实实现只需把 Mock 换成 mysql_*() 系列函数调用
// ════════════════════════════════════════════════════════════
/// Mock MySQL connection used in place of a real driver. It keeps only
/// in-memory state and logs what it does; the commented-out `mysql_*` calls
/// show where a real implementation would go.
class MockMySQLConnection : public IConnection {
public:
    MockMySQLConnection() = default;

    /// Disconnects on destruction. Logging can throw, and a destructor must
    /// not, so any exception is swallowed here.
    ~MockMySQLConnection() override {
        try {
            MockMySQLConnection::disconnect();
        } catch (...) {
        }
    }

    // ── Connect ──
    /// Stores `cfg` and simulates a connection attempt.
    /// @return true when the simulated attempt succeeded.
    bool connect(const PoolConfig& cfg) override {
        cfg_ = cfg;
        // Real code:
        // mysql_ = mysql_init(nullptr);
        // mysql_options(mysql_, MYSQL_OPT_CONNECT_TIMEOUT, &timeout);
        // connected_ = (mysql_real_connect(mysql_,
        //     cfg.host.c_str(), cfg.user.c_str(),
        //     cfg.password.c_str(), cfg.database.c_str(),
        //     cfg.port, nullptr, 0) != nullptr);

        // Simulation (for testing): every 10th attempt across all mock
        // connections fails. The counter is shared by every instance and
        // connect() is reached from several threads, so it must be atomic.
        static std::atomic<int> attempts{0};
        const int attempt = attempts.fetch_add(1, std::memory_order_relaxed) + 1;
        connected_ = (attempt % 10 != 0);
        if (connected_) {
            std::println("[Conn#{}] connected to {}", pool_id, cfg.dsn());
        }
        return connected_;
    }

    /// Drops the simulated connection. Transaction state is cleared as well:
    /// an open transaction cannot survive the connection, and leaving the
    /// flag set would make begin_transaction() fail after a reconnect.
    void disconnect() override {
        in_transaction_ = false;
        if (connected_) {
            // mysql_close(mysql_);
            std::println("[Conn#{}] disconnected", pool_id);
            connected_ = false;
        }
    }

    bool is_connected() const override { return connected_; }

    /// Heartbeat: false when disconnected, otherwise a simulated success.
    bool ping() override {
        if (!connected_) return false;
        // return mysql_ping(mysql_) == 0;
        return true;
    }

    // ── Execute SQL ──
    QueryResult execute(const std::string& sql) override {
        return execute(sql, {});
    }

    /// Simulated execution. Statements containing "SELECT" or "select"
    /// return one fixed row (id, name, value); every statement reports one
    /// affected row. `params` is ignored by the mock.
    /// @throws PoolError(InvalidConn) when not connected.
    QueryResult execute(const std::string& sql,
                        const Params& params) override
    {
        (void)params;
        if (!connected_)
            throw PoolError(PoolErrc::InvalidConn);
        last_used = std::chrono::steady_clock::now();
        // Real code:
        // MYSQL_STMT* stmt = mysql_stmt_init(mysql_);
        // mysql_stmt_prepare(stmt, sql.c_str(), sql.size());
        // ... bind parameters ...
        // mysql_stmt_execute(stmt);
        // ... read the result set ...

        QueryResult result;
        const bool is_select = sql.find("SELECT") != std::string::npos
                            || sql.find("select") != std::string::npos;
        if (is_select) {
            Row row;
            row.columns = {"id", "name", "value"};
            row.values  = {
                std::int64_t(1),
                std::string("test"),
                3.14
            };
            row.index = {{"id", 0}, {"name", 1}, {"value", 2}};
            result.rows.push_back(std::move(row));
            result.field_count = 3;
        }
        result.affected_rows  = 1;
        result.last_insert_id = last_insert_id_;
        // Keep the affected_rows() accessor in step with the returned result.
        affected_rows_ = result.affected_rows;
        return result;
    }

    // ── Transactions ──
    /// @throws PoolError(QueryFailed) when a transaction is already open.
    void begin_transaction() override {
        if (in_transaction_)
            throw PoolError(PoolErrc::QueryFailed, "already in transaction");
        // mysql_query(mysql_, "BEGIN");
        in_transaction_ = true;
        std::println("[Conn#{}] BEGIN", pool_id);
    }

    void commit() override {
        // mysql_commit(mysql_);
        in_transaction_ = false;
        std::println("[Conn#{}] COMMIT", pool_id);
    }

    void rollback() override {
        // mysql_rollback(mysql_);
        in_transaction_ = false;
        std::println("[Conn#{}] ROLLBACK", pool_id);
    }

    bool in_transaction() const override { return in_transaction_; }

    // ── Misc ──
    /// WARNING: the mock returns `s` unchanged and performs NO escaping.
    /// A real implementation must call mysql_real_escape_string:
    // char buf[s.size()*2+1];
    // mysql_real_escape_string(mysql_, buf, s.c_str(), s.size());
    std::string escape(const std::string& s) const override {
        return s;
    }

    std::uint64_t last_insert_id() const override { return last_insert_id_; }
    std::uint64_t affected_rows() const override  { return affected_rows_; }
    std::string   server_version() const override { return "8.0.32-mock"; }
    std::string   error_message() const override  { return err_msg_; }
    int           error_code() const override     { return err_code_; }

private:
    PoolConfig    cfg_;
    bool          connected_      = false;
    bool          in_transaction_ = false;
    std::uint64_t last_insert_id_ = 0;
    std::uint64_t affected_rows_  = 0;
    std::string   err_msg_;
    int           err_code_       = 0;
};
// ── 工厂函数 ──────────────────────────────────────────────
/// Returns a factory that produces a fresh, not-yet-connected
/// MockMySQLConnection on every call.
inline ConnectionFactory make_mysql_factory() {
    ConnectionFactory factory = []() -> ConnectionPtr {
        return std::make_shared<MockMySQLConnection>();
    };
    return factory;
}
} // namespace db
pool_monitor.hpp
#pragma once
#include <atomic>
#include <chrono>
#include <string>
#include <format>
#include <print>
#include <numeric>
#include <vector>
#include <mutex>
namespace db {
// ════════════════════════════════════════════════════════════
// PoolStats:连接池统计
// ════════════════════════════════════════════════════════════
/// Point-in-time snapshot of pool statistics.
struct PoolStats {
    // Connection gauges
    std::size_t total_connections  = 0;
    std::size_t idle_connections   = 0;
    std::size_t active_connections = 0;
    std::size_t pending_requests   = 0;  // requests waiting for a connection

    // Monotonic counters
    std::uint64_t total_acquired   = 0;  // acquisitions
    std::uint64_t total_released   = 0;  // releases
    std::uint64_t total_created    = 0;  // connections created
    std::uint64_t total_destroyed  = 0;  // connections destroyed
    std::uint64_t total_timeouts   = 0;  // timeouts
    std::uint64_t total_errors     = 0;  // errors
    std::uint64_t total_reconnects = 0;  // reconnects

    // Latency figures, in microseconds
    double avg_acquire_us  = 0.0;
    double avg_query_us    = 0.0;
    double peak_acquire_us = 0.0;

    // Reference point for the uptime shown by format()
    std::chrono::steady_clock::time_point start_time
        = std::chrono::steady_clock::now();

    /// Renders the snapshot as a boxed multi-line text report. Uptime is the
    /// whole number of seconds elapsed since `start_time`.
    [[nodiscard]] std::string format() const {
        namespace chr = std::chrono;
        const auto alive_for = chr::steady_clock::now() - start_time;
        const auto uptime    = chr::duration_cast<chr::seconds>(alive_for).count();
        return std::format(
            "╔══════════════ Pool Stats ══════════════╗\n"
            "║ Uptime: {:>8} s ║\n"
            "║ Connections: total={:>3} idle={:>3} active={:>3}║\n"
            "║ Pending: {:>3} requests waiting ║\n"
            "║ Acquire: total={:<8} timeout={:<5}║\n"
            "║ Created: {:<6} Destroyed: {:<6} ║\n"
            "║ Reconnects: {:<6} Errors: {:<6} ║\n"
            "║ Avg acquire: {:>8.1f} μs ║\n"
            "║ Peak acq: {:>8.1f} μs ║\n"
            "╚════════════════════════════════════════╝",
            uptime,
            total_connections, idle_connections, active_connections,
            pending_requests,
            total_acquired, total_timeouts,
            total_created, total_destroyed,
            total_reconnects, total_errors,
            avg_acquire_us, peak_acquire_us);
    }
};
// ════════════════════════════════════════════════════════════
// PoolMonitor:线程安全统计收集器
// ════════════════════════════════════════════════════════════
/// Thread-safe statistics collector for the pool.
///
/// Averages are derived from an atomic sum and an atomic count when a
/// snapshot is taken. A "load average / compute / store" running mean would
/// lose samples when two threads record concurrently; summing with fetch_add
/// cannot. All members are independent statistics counters, so relaxed
/// ordering is sufficient.
class PoolMonitor {
public:
    /// Records one acquire attempt.
    /// @param us       time the attempt took
    /// @param success  true: counted as an acquisition and folded into the
    ///                 average/peak latency; false: counted as a timeout.
    void record_acquire(std::chrono::microseconds us, bool success) {
        if (!success) {
            total_timeouts_.fetch_add(1, std::memory_order_relaxed);
            return;
        }
        const std::uint64_t sample = to_sample(us);
        acquire_us_sum_.fetch_add(sample, std::memory_order_relaxed);
        total_acquired_.fetch_add(1, std::memory_order_relaxed);

        // Raise the peak if this sample exceeds it; on CAS failure `peak` is
        // refreshed with the current value and the comparison is retried.
        std::uint64_t peak = peak_acquire_us_.load(std::memory_order_relaxed);
        while (sample > peak &&
               !peak_acquire_us_.compare_exchange_weak(
                   peak, sample, std::memory_order_relaxed))
        {}
    }

    /// Records the duration of one query.
    void record_query(std::chrono::microseconds us) {
        query_us_sum_.fetch_add(to_sample(us), std::memory_order_relaxed);
        total_queries_.fetch_add(1, std::memory_order_relaxed);
    }

    void record_created()   { total_created_.fetch_add(1, std::memory_order_relaxed); }
    void record_destroyed() { total_destroyed_.fetch_add(1, std::memory_order_relaxed); }
    void record_released()  { total_released_.fetch_add(1, std::memory_order_relaxed); }
    void record_error()     { total_errors_.fetch_add(1, std::memory_order_relaxed); }
    void record_reconnect() { total_reconnects_.fetch_add(1, std::memory_order_relaxed); }

    /// Builds a snapshot from the caller-supplied gauges and the internal
    /// counters. The counters are read one by one, so a snapshot taken while
    /// other threads are recording may be off by the in-flight samples.
    /// @param total    current number of connections
    /// @param idle     current number of idle connections
    /// @param pending  current number of waiting requests
    [[nodiscard]] PoolStats snapshot(
        std::size_t total, std::size_t idle,
        std::size_t pending) const noexcept
    {
        PoolStats s;
        s.total_connections  = total;
        s.idle_connections   = idle;
        s.active_connections = total > idle ? total - idle : 0;
        s.pending_requests   = pending;
        s.total_acquired     = total_acquired_.load(std::memory_order_relaxed);
        s.total_released     = total_released_.load(std::memory_order_relaxed);
        s.total_created      = total_created_.load(std::memory_order_relaxed);
        s.total_destroyed    = total_destroyed_.load(std::memory_order_relaxed);
        s.total_timeouts     = total_timeouts_.load(std::memory_order_relaxed);
        s.total_errors       = total_errors_.load(std::memory_order_relaxed);
        s.total_reconnects   = total_reconnects_.load(std::memory_order_relaxed);
        s.avg_acquire_us     = mean(acquire_us_sum_.load(std::memory_order_relaxed),
                                    s.total_acquired);
        s.avg_query_us       = mean(query_us_sum_.load(std::memory_order_relaxed),
                                    total_queries_.load(std::memory_order_relaxed));
        s.peak_acquire_us    = static_cast<double>(
            peak_acquire_us_.load(std::memory_order_relaxed));
        s.start_time         = start_time_;
        return s;
    }

    /// Prints the formatted snapshot to standard output.
    void print_stats(std::size_t total, std::size_t idle,
                     std::size_t pending) const {
        std::println("{}", snapshot(total, idle, pending).format());
    }

private:
    // Converts a duration to a non-negative microsecond count.
    static std::uint64_t to_sample(std::chrono::microseconds us) noexcept {
        return us.count() > 0 ? static_cast<std::uint64_t>(us.count()) : 0;
    }

    // sum / count, or 0 when nothing has been recorded yet.
    static double mean(std::uint64_t sum, std::uint64_t count) noexcept {
        return count == 0 ? 0.0
                          : static_cast<double>(sum) / static_cast<double>(count);
    }

    std::atomic<std::uint64_t> total_acquired_  { 0 };
    std::atomic<std::uint64_t> total_released_  { 0 };
    std::atomic<std::uint64_t> total_created_   { 0 };
    std::atomic<std::uint64_t> total_destroyed_ { 0 };
    std::atomic<std::uint64_t> total_timeouts_  { 0 };
    std::atomic<std::uint64_t> total_errors_    { 0 };
    std::atomic<std::uint64_t> total_reconnects_{ 0 };
    std::atomic<std::uint64_t> total_queries_   { 0 };
    std::atomic<std::uint64_t> acquire_us_sum_  { 0 };  // sum of acquire latencies (µs)
    std::atomic<std::uint64_t> query_us_sum_    { 0 };  // sum of query latencies (µs)
    std::atomic<std::uint64_t> peak_acquire_us_ { 0 };  // largest acquire latency (µs)
    std::chrono::steady_clock::time_point start_time_
        = std::chrono::steady_clock::now();
};
} // namespace db
connection_pool.hpp(核心)
#pragma once
#include "db_connection.hpp"
#include "pool_monitor.hpp"
#include <deque>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <atomic>
#include <ranges>
#include <stop_token> // C++20
#include <semaphore> // C++20
#include <print>
namespace db {
// ════════════════════════════════════════════════════════════
// ConnectionPool:核心连接池
// ════════════════════════════════════════════════════════════
/// Core connection pool.
///
/// Invariant: every connection counted in `total_count_` holds exactly one
/// slot of `capacity_sem_`, whether it is idle or lent out. A slot is taken
/// before a connection is created and given back when the connection is
/// destroyed, so the pool never holds more than `max_connections`.
class ConnectionPool {
public:
    // ── Construction / destruction ──

    /// Validates `cfg`, pre-builds `initial_connections` connections (failed
    /// attempts are skipped) and starts the maintenance thread.
    /// @throws PoolError(ConnectFailed, "invalid config") when the
    ///         configuration is rejected.
    explicit ConnectionPool(PoolConfig cfg,
                            ConnectionFactory factory)
        : cfg_(validated(std::move(cfg)))
        , factory_(std::move(factory))
        , capacity_sem_(static_cast<std::ptrdiff_t>(cfg_.max_connections))
    {
        std::size_t prebuilt = 0;
        for (std::size_t i = 0; i < cfg_.initial_connections; ++i) {
            // Pre-built connections occupy capacity slots like any other.
            if (!capacity_sem_.try_acquire()) break;
            auto conn = create_connection();
            if (!conn) {
                capacity_sem_.release();
                continue;
            }
            std::scoped_lock lock(mutex_);
            idle_queue_.push_back(std::move(conn));
            ++prebuilt;
        }
        // Log before the maintenance thread exists, so nothing else can be
        // touching the pool while its state is read.
        std::println("[Pool] started: dsn={} init={}/max={}",
                     cfg_.dsn(), prebuilt, cfg_.max_connections);
        maintenance_thread_ = std::jthread([this](std::stop_token st) {
            maintenance_loop(st);
        });
    }

    /// Closes the pool. Exceptions are swallowed because a destructor must
    /// not throw.
    ~ConnectionPool() {
        try {
            close();
        } catch (...) {
        }
    }

    // ── Acquire ──

    /// Hands out a connection, waiting at most `acquire_timeout` overall.
    ///
    /// Order of preference: an idle connection; otherwise a newly created one
    /// if a capacity slot is free (tried once per call); otherwise wait for a
    /// connection to be released.
    /// @throws PoolError(PoolClosed) when the pool is or becomes closed.
    /// @throws PoolError(Timeout) when no connection became available in time.
    [[nodiscard]] ConnectionPtr acquire() {
        if (closed_.load(std::memory_order_acquire))
            throw PoolError(PoolErrc::PoolClosed);

        const auto t0       = std::chrono::steady_clock::now();
        const auto deadline = t0 + cfg_.acquire_timeout;

        ++pending_requests_;
        // Decrements pending_requests_ on every exit path.
        struct PendingGuard {
            std::atomic<std::size_t>& ref;
            ~PendingGuard() { --ref; }
        } guard{ pending_requests_ };

        std::unique_lock lock(mutex_);
        bool may_grow = true;
        for (;;) {
            if (closed_.load(std::memory_order_relaxed))
                throw PoolError(PoolErrc::PoolClosed);

            // (1) Reuse an idle connection when there is one.
            if (!idle_queue_.empty())
                return pop_idle(lock, t0);

            // (2) No idle connection: grow the pool if a slot is free. The
            //     slot check never blocks; connecting happens without the
            //     pool lock so other threads can keep releasing/acquiring.
            if (may_grow && capacity_sem_.try_acquire()) {
                lock.unlock();
                if (auto conn = create_connection()) {
                    conn->in_use = true;
                    record_acquire(t0, true);
                    return conn;
                }
                capacity_sem_.release();
                may_grow = false;  // creation already retried internally
                lock.lock();
                continue;
            }

            // (3) Wait for a released connection until the shared deadline.
            const bool signalled = cond_.wait_until(lock, deadline, [this] {
                return !idle_queue_.empty() ||
                       closed_.load(std::memory_order_relaxed);
            });
            if (!signalled) {
                record_acquire(t0, false);
                throw PoolError(PoolErrc::Timeout,
                    std::format("waited {}ms, pool size={}",
                                cfg_.acquire_timeout.count(),
                                total_count_.load()));
            }
        }
    }

    // ── Release ──

    /// Returns a connection to the pool. A leaked transaction is rolled
    /// back; a disconnected connection is reconnected or destroyed; after
    /// close() the connection is disconnected instead of being queued.
    void release(ConnectionPtr conn) {
        if (!conn) return;
        conn->in_use    = false;
        conn->last_used = std::chrono::steady_clock::now();

        // Transaction-leak detection.
        if (conn->in_transaction()) {
            std::println("[Pool] WARN: connection#{} released "
                         "with active transaction, rolling back",
                         conn->pool_id);
            try { conn->rollback(); } catch (...) {}
        }

        if (closed_.load(std::memory_order_acquire)) {
            discard(conn);
            return;
        }

        // Health check.
        if (!conn->is_connected()) {
            handle_bad_connection(conn);
            return;
        }

        if (!try_enqueue_idle(conn)) {
            // The pool was closed between the check above and the enqueue.
            discard(conn);
            return;
        }
        monitor_.record_released();
    }

    // ── Close ──

    /// Stops the maintenance thread, wakes all waiters and disconnects every
    /// idle connection. Idempotent. Connections still lent out are
    /// disconnected when they are released.
    void close() {
        bool expected = false;
        if (!closed_.compare_exchange_strong(expected, true))
            return;  // already closed

        maintenance_thread_.request_stop();
        {
            // Passing through the mutex guarantees that a waiter which has
            // evaluated its predicate is actually blocked before the notify,
            // so the wake-up cannot be missed.
            std::scoped_lock lock(mutex_);
        }
        cond_.notify_all();
        if (maintenance_thread_.joinable())
            maintenance_thread_.join();

        std::deque<ConnectionPtr> doomed;
        {
            std::scoped_lock lock(mutex_);
            doomed.swap(idle_queue_);
        }
        for (auto& conn : doomed)
            discard(conn);

        std::println("[Pool] closed, destroyed {} connections",
                     doomed.size());
    }

    // ── Statistics ──

    [[nodiscard]] PoolStats stats() const {
        std::scoped_lock lock(mutex_);
        return monitor_.snapshot(
            total_count_.load(),
            idle_queue_.size(),
            pending_requests_.load());
    }

    /// Prints the current statistics; the snapshot is taken under the lock
    /// and printed after it is dropped.
    void print_stats() const {
        std::println("{}", stats().format());
    }

    // ── Query (acquires and releases a connection automatically) ──

    /// Runs `sql` on a pooled connection and returns the result. Exceptions
    /// from execution are counted as errors and rethrown.
    QueryResult query(const std::string& sql,
                      const Params& params = {})
    {
        auto conn = acquire();
        ScopedRelease guard{ *this, conn };
        const auto t0 = std::chrono::steady_clock::now();
        try {
            auto result = conn->execute(sql, params);
            monitor_.record_query(elapsed_us(t0));
            return result;
        } catch (...) {
            monitor_.record_error();
            throw;
        }
    }

    // ── Transaction ──

    /// Runs `fn(connection)` inside BEGIN/COMMIT on a pooled connection and
    /// returns whatever `fn` returns (callables returning void are
    /// supported). On any exception the transaction is rolled back and the
    /// exception rethrown.
    template<std::invocable<IConnection&> Fn>
    auto transaction(Fn&& fn) -> std::invoke_result_t<Fn, IConnection&>
    {
        using R = std::invoke_result_t<Fn, IConnection&>;
        auto conn = acquire();
        ScopedRelease guard{ *this, conn };
        conn->begin_transaction();
        try {
            if constexpr (std::is_void_v<R>) {
                std::forward<Fn>(fn)(*conn);
                conn->commit();
            } else {
                R result = std::forward<Fn>(fn)(*conn);
                conn->commit();
                return result;
            }
        } catch (...) {
            try { conn->rollback(); } catch (...) {}
            monitor_.record_error();
            throw;
        }
    }

    // ── Configuration ──
    const PoolConfig& config() const noexcept { return cfg_; }

private:
    // Releases a borrowed connection at scope exit. release() can throw
    // (logging, reconnecting) and a destructor must not, so failures are
    // swallowed here.
    struct ScopedRelease {
        ConnectionPool& pool;
        ConnectionPtr   conn;
        ~ScopedRelease() {
            try { pool.release(std::move(conn)); } catch (...) {}
        }
    };

    // Rejects unusable configurations before any member depends on them;
    // max_connections must also fit the semaphore's counter.
    static PoolConfig validated(PoolConfig cfg) {
        const auto sem_max =
            static_cast<std::size_t>(std::counting_semaphore<>::max());
        if (!cfg.valid() || cfg.max_connections > sem_max)
            throw PoolError(PoolErrc::ConnectFailed, "invalid config");
        return cfg;
    }

    // Creates and connects a new connection, retrying up to
    // max_reconnect_attempts times with reconnect_delay in between. Returns
    // nullptr (and records an error) when every attempt failed. The caller
    // must already hold a capacity slot.
    [[nodiscard]] ConnectionPtr create_connection() {
        for (int attempt = 0;
             attempt < cfg_.max_reconnect_attempts; ++attempt)
        {
            try {
                auto conn = factory_();
                if (conn) {  // a factory may return null; treat as a failed attempt
                    conn->pool_id = ++next_id_;
                    if (conn->connect(cfg_)) {
                        ++total_count_;
                        monitor_.record_created();
                        return conn;
                    }
                }
            } catch (const std::exception& e) {
                std::println("[Pool] connect attempt {}/{} failed: {}",
                             attempt + 1, cfg_.max_reconnect_attempts, e.what());
            }
            if (attempt + 1 < cfg_.max_reconnect_attempts) {
                std::this_thread::sleep_for(cfg_.reconnect_delay);
            }
        }
        monitor_.record_error();
        return nullptr;
    }

    // Pops the front idle connection. `lock` must hold mutex_ and the queue
    // must be non-empty; the lock is released before returning.
    [[nodiscard]] ConnectionPtr pop_idle(
        std::unique_lock<std::mutex>& lock,
        std::chrono::steady_clock::time_point t0)
    {
        auto conn = std::move(idle_queue_.front());
        idle_queue_.pop_front();
        lock.unlock();
        conn->in_use = true;
        record_acquire(t0, true);
        return conn;
    }

    // Appends `conn` to the idle queue and wakes one waiter. Returns false,
    // leaving `conn` untouched, when the pool is closed.
    [[nodiscard]] bool try_enqueue_idle(ConnectionPtr& conn) {
        {
            std::scoped_lock lock(mutex_);
            if (closed_.load(std::memory_order_relaxed)) return false;
            idle_queue_.push_back(std::move(conn));
        }
        cond_.notify_one();
        return true;
    }

    // Bookkeeping for a destroyed connection: drop it from the count and
    // give its capacity slot back.
    void drop_slot() {
        --total_count_;
        monitor_.record_destroyed();
        capacity_sem_.release();
    }

    // Disconnects `conn` and removes it from the pool's accounting.
    void discard(ConnectionPtr& conn) {
        try { conn->disconnect(); } catch (...) {}
        drop_slot();
    }

    // Tries one reconnect of a broken connection; on success it goes back to
    // the idle queue, otherwise it is destroyed and the minimum is topped up.
    void handle_bad_connection(ConnectionPtr& conn) {
        std::println("[Pool] Conn#{} is bad, reconnecting...",
                     conn->pool_id);
        monitor_.record_reconnect();
        conn->disconnect();

        bool reconnected = false;
        try {
            reconnected = conn->connect(cfg_);
        } catch (...) {
            // A throwing connect counts as a failed reconnect.
            reconnected = false;
        }

        if (reconnected) {
            if (try_enqueue_idle(conn)) return;
            // Pool closed meanwhile: take the fresh connection down again.
            try { conn->disconnect(); } catch (...) {}
        }
        drop_slot();
        ensure_min_connections();
    }

    // Creates connections until min_connections is reached, capacity runs
    // out, a creation fails, or the pool is closed.
    void ensure_min_connections() {
        while (!closed_.load(std::memory_order_acquire) &&
               total_count_.load() < cfg_.min_connections)
        {
            if (!capacity_sem_.try_acquire()) break;
            auto conn = create_connection();
            if (!conn) {
                capacity_sem_.release();
                break;
            }
            if (!try_enqueue_idle(conn)) {
                discard(conn);
                break;
            }
        }
    }

    // Background maintenance: periodic heartbeat / idle-expiry pass followed
    // by topping up to the minimum connection count.
    void maintenance_loop(std::stop_token st) {
        using namespace std::chrono;
        // Period: the shorter of keepalive_interval and a quarter of
        // idle_timeout, never below one second so a tiny setting cannot turn
        // the loop into a busy spin.
        auto interval = std::min(
            duration_cast<milliseconds>(cfg_.keepalive_interval),
            duration_cast<milliseconds>(cfg_.idle_timeout) / 4);
        if (interval < milliseconds{1000}) interval = milliseconds{1000};

        while (!st.stop_requested()) {
            {
                // Stop-aware wait: returns as soon as stop is requested, so
                // close() does not have to sit out the whole interval.
                std::unique_lock lock(maintenance_mutex_);
                maintenance_cv_.wait_for(lock, st, interval,
                                         [] { return false; });
            }
            if (st.stop_requested()) break;
            try {
                heartbeat_check();
                ensure_min_connections();
            } catch (...) {
                // Nothing may escape a thread entry point.
                monitor_.record_error();
            }
            // Periodic statistics (debugging):
            // print_stats();
        }
    }

    // One pass over the idle queue: connections idle longer than
    // idle_timeout are closed (never going below min_connections), the rest
    // are pinged when health checks are enabled.
    // NOTE(review): ping() runs while mutex_ is held, so a slow ping stalls
    // acquire()/release() for its duration.
    void heartbeat_check() {
        using namespace std::chrono;
        const auto now = steady_clock::now();
        std::vector<ConnectionPtr> bad_conns;
        std::vector<ConnectionPtr> expired_conns;
        {
            std::scoped_lock lock(mutex_);
            // Connections that would remain if everything marked so far were
            // closed; keeps a single pass from dropping below the minimum.
            std::size_t remaining = total_count_.load();
            std::deque<ConnectionPtr> survivors;
            for (auto& conn : idle_queue_) {
                const auto idle_dur = duration_cast<seconds>(
                    now - conn->last_used);
                if (idle_dur > cfg_.idle_timeout &&
                    remaining > cfg_.min_connections)
                {
                    --remaining;
                    expired_conns.push_back(std::move(conn));
                    continue;
                }
                if (cfg_.enable_health_check && !conn->ping()) {
                    bad_conns.push_back(std::move(conn));
                    continue;
                }
                // last_used is deliberately left alone: refreshing it on
                // every heartbeat would stop the idle timeout from ever
                // being reached.
                survivors.push_back(std::move(conn));
            }
            idle_queue_ = std::move(survivors);
        }

        // Expired connections: close and free their slots.
        for (auto& conn : expired_conns) {
            std::println("[Pool] Conn#{} idle timeout, closing",
                         conn->pool_id);
            discard(conn);
        }
        // Bad connections: try to reconnect.
        for (auto& conn : bad_conns) {
            handle_bad_connection(conn);
        }
    }

    // ── Helpers ──
    void record_acquire(
        std::chrono::steady_clock::time_point t0,
        bool success)
    {
        monitor_.record_acquire(elapsed_us(t0), success);
    }

    static std::chrono::microseconds elapsed_us(
        std::chrono::steady_clock::time_point t0) noexcept
    {
        return std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - t0);
    }

    // ── Members ──
    PoolConfig                cfg_;
    ConnectionFactory         factory_;
    mutable std::mutex        mutex_;        // guards idle_queue_
    std::condition_variable   cond_;         // signalled when a connection becomes idle / on close
    std::deque<ConnectionPtr> idle_queue_;   // idle connections, FIFO
    // Counting semaphore: one slot per existing connection, max_connections in total.
    std::counting_semaphore<> capacity_sem_;
    std::atomic<std::size_t>  total_count_     { 0 };
    std::atomic<std::size_t>  pending_requests_{ 0 };
    std::atomic<std::size_t>  next_id_         { 0 };
    std::atomic<bool>         closed_          { false };
    PoolMonitor               monitor_;
    // Used only for the maintenance thread's interruptible sleep.
    std::mutex                  maintenance_mutex_;
    std::condition_variable_any maintenance_cv_;
    // Declared last so it is destroyed (stopped and joined) before the
    // members it uses.
    std::jthread              maintenance_thread_;
};
} // namespace db
pool_guard.hpp
#pragma once
#include "connection_pool.hpp"
namespace db {
// ════════════════════════════════════════════════════════════
// ConnectionGuard:RAII连接守卫
// 离开作用域自动归还连接
// ════════════════════════════════════════════════════════════
class ConnectionGuard {
public:
ConnectionGuard(ConnectionPool& pool)
: pool_(&pool)
, conn_(pool.acquire())
{}
~ConnectionGuard() {
if (pool_ && conn_)
pool_->release(std::move(conn_));
}
// 不可拷贝,可移动
ConnectionGuard(const ConnectionGuard&) = delete;
ConnectionGuard& operator=(const ConnectionGuard&) = delete;
ConnectionGuard(ConnectionGuard&& o) noexcept
: pool_(o.pool_), conn_(std::move(o.conn_)) {
o.pool_ = nullptr;
}
// ── 透明访问 ──────────────────────────────────────────
IConnection* operator->() const noexcept { return conn_.get(); }
IConnection& operator*() const noexcept { return *conn_; }
// ── SQL快捷方式 ───────────────────────────────────────
QueryResult execute(const std::string& sql,
const Params& params = {}) {
return conn_->execute(sql, params);
}
// ── 事务支持 ──────────────────────────────────────────
void begin() { conn_->begin_transaction(); }
void commit() { conn_->commit(); }
void rollback() { conn_->rollback(); }
// ── 手动提前归还 ──────────────────────────────────────
void release() {
if (pool_ && conn_) {
pool_->release(std::move(conn_));
pool_ = nullptr;
}
}
private:
ConnectionPool* pool_;
ConnectionPtr conn_;
};
// ════════════════════════════════════════════════════════════
// TransactionGuard:事务RAII守卫
// 异常时自动回滚,成功需手动commit
// ════════════════════════════════════════════════════════════
/// Scope guard for a transaction on a ConnectionGuard: begins in the
/// constructor and rolls back in the destructor unless commit() or
/// rollback() completed first.
class TransactionGuard {
public:
    /// Begins a transaction on the connection held by `guard`.
    explicit TransactionGuard(ConnectionGuard& guard)
        : guard_(guard)
    {
        guard_.begin();
    }

    /// Rolls back if the transaction was neither committed nor rolled back;
    /// errors during that rollback are ignored.
    ~TransactionGuard() {
        if (finished_) {
            return;
        }
        try {
            guard_.rollback();
        } catch (...) {
        }
    }

    TransactionGuard(const TransactionGuard&) = delete;
    TransactionGuard& operator=(const TransactionGuard&) = delete;

    /// Commits; the guard counts as finished only if the commit did not throw.
    void commit() {
        guard_.commit();
        finished_ = true;
    }

    /// Rolls back explicitly; finished only if the rollback did not throw.
    void rollback() {
        guard_.rollback();
        finished_ = true;
    }

private:
    ConnectionGuard& guard_;
    bool             finished_ = false;  // committed or rolled back
};
} // namespace db
db_pool.hpp(统一入口)
#pragma once
#include "connection_pool.hpp"
#include "pool_guard.hpp"
#include "mysql_connection.hpp"
namespace db {
// ── Process-wide singleton connection pool ───────────────────
// Static facade over a single ConnectionPool held in a function-local
// static unique_ptr. No locking is performed around init()/shutdown().
class DB {
public:
    /// Builds a ConnectionPool from `cfg` using the MySQL factory and
    /// stores it as the singleton, replacing any pool stored earlier.
    static void init(PoolConfig cfg) {
        auto fresh = std::make_unique<ConnectionPool>(
            std::move(cfg), make_mysql_factory());
        instance() = std::move(fresh);
    }

    /// Returns the singleton pool.
    /// @throws PoolError (PoolErrc::PoolClosed) when no pool is stored,
    ///         i.e. before init() or after shutdown().
    static ConnectionPool& pool() {
        if (auto& slot = instance(); slot) {
            return *slot;
        }
        throw PoolError(PoolErrc::PoolClosed,
                        "DB::init() not called");
    }

    /// Creates a ConnectionGuard bound to the singleton pool.
    static ConnectionGuard conn() {
        return ConnectionGuard(pool());
    }

    /// Forwards `sql` and `params` to the pool's query().
    static QueryResult query(const std::string& sql,
                             const Params& params = {}) {
        auto& active_pool = pool();
        return active_pool.query(sql, params);
    }

    /// Forwards `fn` to the pool's transaction() and returns its result
    /// (by value, as deduced by plain `auto`).
    template<std::invocable<IConnection&> Fn>
    static auto transaction(Fn&& fn) {
        auto& active_pool = pool();
        return active_pool.transaction(std::forward<Fn>(fn));
    }

    /// Destroys the stored pool, if any. pool() throws afterwards until
    /// init() is called again.
    static void shutdown() {
        instance().reset();
    }

private:
    /// Storage for the singleton; empty until init() is called.
    static std::unique_ptr<ConnectionPool>& instance() {
        static std::unique_ptr<ConnectionPool> holder;
        return holder;
    }
};
} // namespace db
三、完整使用示例
#include "db_pool.hpp"
#include <atomic>
#include <chrono>
#include <cstdio>
#include <exception>
#include <print>
#include <thread>
#include <vector>
using namespace db;
using namespace std::chrono_literals;
// ── Basic usage ──────────────────────────────────────────────
/// Shows two ways of running a query: through a scoped ConnectionGuard
/// and through the DB::query() convenience call.
void basic_example() {
    std::println("\n=== 基本使用 ===");

    // 1. Scoped ConnectionGuard (the recommended form)
    {
        ConnectionGuard guard(DB::pool());
        auto users = guard.execute(
            "SELECT id, name FROM users WHERE id = ?",
            { std::int64_t(1) });

        for (const auto& user : users) {
            const auto& id   = user.get_int("id");
            const auto& name = user.get_str("name");
            std::println("id={} name={}", id, name);
        }
    }   // guard goes out of scope here; the connection is handed back

    // 2. One-shot query through the facade (more concise)
    auto counts = DB::query("SELECT COUNT(*) as cnt FROM orders");
    const auto& total = counts[0].get_int("cnt");
    std::println("orders count = {}", total);
}
// ── Transaction usage ────────────────────────────────────────
/// Shows the two transaction styles: a callable handed to
/// DB::transaction(), and an explicit TransactionGuard.
void transaction_example() {
    std::println("\n=== 事务 ===");

    // Style 1: callable-based transaction
    DB::transaction([](IConnection& conn) -> void {
        conn.execute(
            "INSERT INTO accounts(name, balance) VALUES(?,?)",
            { std::string("Alice"), std::int64_t(1000) });

        conn.execute(
            "UPDATE accounts SET balance = balance - ? WHERE id=?",
            { std::int64_t(100), std::int64_t(1) });

        conn.execute(
            "UPDATE accounts SET balance = balance + ? WHERE id=?",
            { std::int64_t(100), std::int64_t(2) });

        // Throwing from here is meant to trigger an automatic rollback:
        // throw std::runtime_error("test rollback");
    });

    // Style 2: TransactionGuard (finer-grained control)
    {
        ConnectionGuard  connection(DB::pool());
        TransactionGuard txn(connection);          // BEGIN
        try {
            connection.execute(
                "DELETE FROM temp_data WHERE expired = 1");
            connection.execute(
                "INSERT INTO logs(msg) VALUES(?)",
                { std::string("cleaned") });
            txn.commit();                          // COMMIT
        } catch (...) {
            // txn was not committed, so its destructor issues the ROLLBACK.
            throw;
        }
    }
}
// ── Multi-threaded concurrency test ──────────────────────────
/// Starts THREADS workers that each run OPS queries through DB::query(),
/// counts successes and failures, then prints the totals and pool stats.
///
/// Every exception derived from std::exception is caught inside the
/// worker: an exception escaping a std::thread function would call
/// std::terminate and abort the whole test run.
void concurrent_test() {
    std::println("\n=== 并发测试 ===");
    constexpr int THREADS = 20;
    constexpr int OPS     = 50;

    std::atomic<int> success{ 0 }, failed{ 0 };

    std::vector<std::thread> workers;
    workers.reserve(THREADS);   // one allocation instead of repeated growth

    for (int t = 0; t < THREADS; ++t) {
        workers.emplace_back([&, t] {
            // Shared bookkeeping for every failure path.
            const auto record_failure = [&](const char* what) {
                ++failed;
                std::println("[T{}] error: {}", t, what);
            };

            for (int i = 0; i < OPS; ++i) {
                try {
                    [[maybe_unused]] auto res = DB::query(
                        "SELECT ? + ? AS result",
                        { std::int64_t(t), std::int64_t(i) });
                    ++success;
                } catch (const PoolError& e) {
                    record_failure(e.what());
                } catch (const std::exception& e) {
                    // Non-pool failures must not leave the thread either.
                    record_failure(e.what());
                }

                // Simulate business processing time on every 5th iteration
                // (periodic, not random).
                if (i % 5 == 0)
                    std::this_thread::sleep_for(1ms);
            }
        });
    }

    // Join before reading the counters so every increment is visible.
    for (auto& w : workers) w.join();

    std::println("success={} failed={}", success.load(), failed.load());
    DB::pool().print_stats();
}
// ── Pool stress test ─────────────────────────────────────────
/// Runs N sequential "SELECT 1" queries through DB::query() and prints
/// the total elapsed time and the mean time per query in microseconds.
void stress_test() {
    std::println("\n=== 压力测试 ===");
    constexpr int N = 1000;

    // steady_clock is monotonic. high_resolution_clock may alias
    // system_clock, which can jump when the wall clock is adjusted and
    // would then corrupt the measured interval.
    using Clock = std::chrono::steady_clock;

    const auto t0 = Clock::now();
    for (int i = 0; i < N; ++i) {
        DB::query("SELECT 1");
    }
    const auto t1 = Clock::now();

    const auto us = std::chrono::duration_cast<
        std::chrono::microseconds>(t1 - t0).count();
    std::println("{}次查询耗时 {} μs,平均 {:.1f} μs/次",
                 N, us, static_cast<double>(us) / N);
}
/// Entry point: builds the pool configuration, initialises the singleton
/// pool, runs every example, prints the final statistics and shuts the
/// pool down.
///
/// @return 0 on success, 1 when any step threw.
///
/// Everything runs inside a try block. An exception that escapes main()
/// calls std::terminate, and whether the stack is unwound first is
/// implementation-defined, so the pool might never be shut down.
int main() {
    try {
        // ── Initialise the connection pool ────────────────────
        auto cfg = PoolConfigBuilder{}
            .host("127.0.0.1")
            .port(3306)
            .user("root")
            .password("123456")
            .database("testdb")
            .min_conn(2)
            .max_conn(10)
            .acquire_timeout(3s)
            .idle_timeout(300s)
            .keepalive(60s)
            .build();
        DB::init(std::move(cfg));   // cfg is not used afterwards

        // ── Run the examples ──────────────────────────────────
        basic_example();
        transaction_example();
        concurrent_test();
        stress_test();

        DB::pool().print_stats();
        DB::shutdown();
        return 0;
    } catch (const std::exception& e) {
        std::println(stderr, "fatal: {}", e.what());
    } catch (...) {
        std::println(stderr, "fatal: unknown exception");
    }

    // Failure path: shutdown() only resets the stored pointer, so it is
    // safe even when init() never completed.
    DB::shutdown();
    return 1;
}
四、原理总结
┌──────────────────────────────────────────────────────────────┐
│ acquire() 流程 │
│ │
│ ┌─────────┐ idle队列有连接? │
│ │ 请求到来 │──────── YES ──────► 直接取走 O(1) ◄── 快路径 │
│ └─────────┘ │
│ │ NO │
│ ▼ │
│ 总连接数 < max? │
│ ──── YES ──► 信号量acquire ──► 新建连接 ──► 返回 │
│ │ NO │
│ ▼ │
│ 条件变量等待(带超时) ──► 被release()唤醒 ──► 取连接 │
│ │ 超时 │
│ ▼ │
│ 抛出 PoolError::Timeout │
└──────────────────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────────────────┐
│ 后台维护线程 │
│ │
│ 每 keepalive_interval 秒执行: │
│ ① ping所有空闲连接 → 失败则重连 → 重连失败则销毁 │
│ ② 检查空闲超时 → 超过idle_timeout且数量>min → 关闭 │
│ ③ 当前连接数 < min_connections → 补充新连接 │
└──────────────────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────────────────┐
│ C++20 特性使用 │
├─────────────────────┬────────────────────────────────────────┤
│ std::jthread │ 后台线程,析构自动stop+join │
│ std::stop_token │ 优雅停止后台线程 │
│ std::counting_ │ 控制最大连接数上限(替代手写计数器) │
│ semaphore │ │
│ std::format/println │ 结构化日志输出(std::println 属于 C++23) │
│ std::atomic<double> │ 无锁统计均值 │
│ concepts/ranges │ 类型约束与范围操作 │
└─────────────────────┴────────────────────────────────────────┘
更多推荐

所有评论(0)