C++ 工作窃取线程池(WorkStealingPool)深度解析:原理、实现与实践
·
前言
在高并发多核时代,传统的固定 / 缓存线程池在负载不均时会出现大量线程空转、CPU 利用率低的问题。而工作窃取(Work Stealing)算法正是解决这一痛点的银弹。
一、为什么需要工作窃取线程池?
传统线程池通常使用一个全局任务队列,所有线程共享访问,这带来了两个核心问题:
- 激烈的锁竞争:多个线程同时从同一个队列取任务,频繁加锁 / 解锁导致性能损耗。
- 负载不均衡:如果任务处理时间差异很大,会出现部分线程忙死、部分线程闲死的情况,多核资源严重浪费。
工作窃取线程池的核心思想是:
- 每个线程拥有独立的任务队列,优先处理自己队列里的任务,减少锁竞争。
- 空闲线程主动 “偷取” 繁忙线程的任务,通过从其他线程的队列尾部窃取任务,实现负载均衡。
这种设计可以显著提高线程利用率,减少等待时间,是现代多核并行计算的首选方案。
二、核心设计与架构概览
核心架构如下:
- 多桶任务队列(
WSyncQueue):为每个工作线程维护一个独立的双端队列(deque)。 - 工作窃取逻辑:线程优先从自己的队列取任务,队列为空时,随机选择其他线程的队列尾部窃取任务。
- 线程安全控制:通过互斥锁和条件变量保证多线程下的安全访问。
- 优雅关闭机制:确保线程池停止时,所有任务都能执行完毕。
三、关键组件:多桶同步队列 WSyncQueue
WSyncQueue 是整个线程池的基石,它维护了多个独立的任务队列,支持任务的提交、获取和窃取。
3.1 核心数据结构
#include <vector>
#include <deque>
#include <mutex>
#include <condition_variable>
using namespace std;
#ifndef SYNCQUEUE_2_HPP
#define SYNCQUEUE_2_HPP
namespace tulun
{
template <class T>
class WSyncQueue
{
private:
std::vector<std::deque<T>> m_taskQueues;// 每个线程一个任务队列
size_t m_bucketSize; // 8;队列数量(等于线程数)
size_t m_maxSize; // // 单个队列的最大容量
mutable std::mutex m_mutex;// 全局互斥锁
std::condition_variable m_notEmpty;// 队列非空条件变量
std::condition_variable m_notFull;// 队列非满条件变量
size_t m_waitTime;
bool m_needStop; // 停止标志
private:
bool IsFull(const size_t index) const
{
bool full = m_taskQueues[index].size() >= m_maxSize;
return full;
}
bool IsEmpty(const size_t index) const
{
return m_taskQueues[index].empty();
}
template <class F>
int Add(F &&task, const size_t index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notFull.wait_for(
locker,
std::chrono::milliseconds(m_waitTime),
[this, index]
{ return m_needStop || !IsFull(index); });
if (m_needStop)
{
return 1;
}
if (!waitret)//(超时/满了)
{
return 2;
}
m_taskQueues[index].push_back(std::forward<F>(task));
m_notEmpty.notify_all();
return 0;
}
size_t GetTotalTaskSize() const
{
size_t total = 0;
for (const auto &que : m_taskQueues)
{
total += que.size();
}
return total;
}
public:
WSyncQueue(int bucketsize = 8, int maxquesize = 500, size_t timeout = 1)
: m_bucketSize(bucketsize),
m_maxSize(maxquesize),
m_waitTime(timeout),
m_needStop(false)
{
// 初始化指定数量的任务队列
////开空间+创建对象(允许下标访问)
m_taskQueues.resize(m_bucketSize);
}
~WSyncQueue()
{
WaitStop();
}
int Put(const T &task, const size_t index)
{
return Add(task, index);
}
int Put(T &&task, const size_t index)
{
return Add(std::forward<T>(task), index);
}
int Take(std::deque<T> &tque, const size_t index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notEmpty.wait_for(
locker,
std::chrono::milliseconds(m_waitTime),
[this, index]() -> bool
{
return m_needStop || !IsEmpty(index);
});
if (m_needStop)
{
return 1;
}
if (!waitret)//(超时/空)
{
return 2;
}
//移动资源,对象还在
tque = std::move(m_taskQueues[index]);
m_notFull.notify_all();
return 0;
}
int Take(T &task, const size_t index)
{
std::unique_lock<std::mutex> locker(m_mutex);
// while (!m_needStop && IsEmpty())
// {
// m_notEmpty.wait(locker);
// }
bool waitret = m_notEmpty.wait_for(
locker,
std::chrono::milliseconds(m_waitTime),
[this, index]() -> bool
{
return m_needStop || !IsEmpty(index);
});
if (m_needStop)
{
return 1;
}
if (!waitret)
{
return 2;
}
task = m_taskQueues[index].front();
m_taskQueues[index].pop_front();
m_notFull.notify_all();
return 0;
}
void WaitStop()
{
std::unique_lock<std::mutex> locker(m_mutex);
while (GetTotalTaskSize() != 0)
{
m_notFull.wait_for(locker, std::chrono::seconds(m_waitTime));
}
// 设置停止标志
m_needStop = true;
// 唤醒所有等待的线程
m_notEmpty.notify_all();
m_notFull.notify_all();
}
size_t TotalTaskSize() const
{
std::unique_lock<std::mutex> locker(m_mutex);
size_t total = 0;
for (const auto &que : m_taskQueues)
{
total += que.size();
}
return total;
}
void PrintTaskInfo() const
{
std::unique_lock<std::mutex> locker(m_mutex);
for (size_t i = 0; i < m_taskQueues.size(); ++i)
{
printf("buck: %zu => %zu count \n", i, m_taskQueues[i].size());
}
}
};
} // namespace tulun
#endif
3.2 核心方法解析
任务提交(Add/Put)
- 提交任务时,线程根据自身 ID 选择一个专属队列,将任务加入尾部。
- 如果队列已满,则等待或直接在当前线程执行任务(降级策略)。
template <class F>
int Add(F &&task, const size_t index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notFull.wait_for(
locker,
std::chrono::milliseconds(m_waitTime),
[this, index]
{ return m_needStop || !IsFull(index); });
if (m_needStop)
{
return 1;
}
if (!waitret)//(超时/满了)
{
return 2;
}
m_taskQueues[index].push_back(std::forward<F>(task));
m_notEmpty.notify_all();
return 0;
}
任务获取(Take)
- 本地获取:线程优先从自己队列的头部获取任务。
- 窃取任务:当本地队列为空时,从其他线程队列的尾部批量窃取任务,减少锁竞争次数。
int Take(std::deque<T> &tque, const size_t index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notEmpty.wait_for(
locker,
std::chrono::milliseconds(m_waitTime),
[this, index]() -> bool
{
return m_needStop || !IsEmpty(index);
});
if (m_needStop)
{
return 1;
}
if (!waitret)//(超时/空)
{
return 2;
}
//移动资源,对象还在
tque = std::move(m_taskQueues[index]);
m_notFull.notify_all();
return 0;
}
int Take(T &task, const size_t index)
{
std::unique_lock<std::mutex> locker(m_mutex);
// while (!m_needStop && IsEmpty())
// {
// m_notEmpty.wait(locker);
// }
bool waitret = m_notEmpty.wait_for(
locker,
std::chrono::milliseconds(m_waitTime),
[this, index]() -> bool
{
return m_needStop || !IsEmpty(index);
});
if (m_needStop)
{
return 1;
}
if (!waitret)
{
return 2;
}
task = m_taskQueues[index].front();
m_taskQueues[index].pop_front();
m_notFull.notify_all();
return 0;
}
优雅关闭(WaitStop)
- 等待所有任务处理完毕后,设置停止标志,唤醒所有等待线程。
void WaitStop()
{
std::unique_lock<std::mutex> locker(m_mutex);
while (GetTotalTaskSize() != 0)
{
m_notFull.wait_for(locker, std::chrono::seconds(m_waitTime));
}
// 设置停止标志
m_needStop = true;
// 唤醒所有等待的线程
m_notEmpty.notify_all();
m_notFull.notify_all();
}
四、工作窃取线程池 WorkStealingThreadPool
这是线程池的核心,它负责管理工作线程、调度任务和实现窃取逻辑。
4.1 核心成员变量
class WorkStealingThreadPool
{
public:
using Task = std::function<void(void)>;
private:
size_t m_numThreads; // 线程数量
tulun::WSyncQueue<Task> m_queue; // 多桶任务队列
std::vector<std::shared_ptr<std::thread>> m_threadgroup; // 线程组
std::atomic<bool> m_running; // 运行标志
std::once_flag m_flag; // 保证Stop只调用一次
// ...
};
4.2 线程启动与任务分配(Start)
- 初始化时创建指定数量的工作线程,并为每个线程分配一个独立的任务队列。
4.3 核心窃取逻辑(RunInThread)
这是工作窃取的灵魂,每个线程都在这个循环中执行任务:
void RunInThread(const size_t index)
{
while (m_running)
{
std::deque<Task> taskque;
// 1. 优先从自己的队列获取任务
m_queue.PrintTaskInfo();
if (m_queue.Take(taskque, index) == 0)
{
for (auto &task : taskque)
{
if (task)
{
task();
}
}
}
else
{
// 2. 本地队列空了,尝试从其他线程窃取任务
size_t i = threadIndex();
if (i != index && m_queue.Take(taskque, i) == 0)
{
for (auto &task : taskque)
{
if (task)
{
task();
}
}
}
}
}
}
- 本地执行:优先处理自己队列里的任务,无需竞争。
- 窃取执行:本地队列为空时,随机选择其他线程的队列进行窃取,实现负载均衡。
4.4 任务提交接口
支持两种任务提交方式:
- 普通任务(无返回值)
void AddTask(Task &&task) { if (m_queue.Put(std::move(task), threadIndex()) != 0) { LOG_ERROR << "Add task run task"; task();// 自己在主线程执行(降级策略) } } void AddTask(const Task &task) { if (m_queue.Put(task, threadIndex()) != 0) { LOG_ERROR << "Add task run task"; task(); } } - 异步任务(带返回值)利用
std::packaged_task和std::future实现异步结果获取:template <class Func, class... Args> auto submit(Func &&func, Args &&...args) { using RetType = decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...)); auto task = std::make_shared<std::packaged_task<RetType()>>( [func = std::forward<Func>(func), ... args = std::forward<Args>(args)]() mutable { return std::invoke(func, args...); }); std::future<RetType> result = task->get_future(); if (m_queue.Put([task]() -> void { (*task)(); }, threadIndex()) != 0) { LOG_ERROR << "Add task run task"; (*task)(); } return result; }
4.5 优雅关闭(StopThreadGroup)
- 调用
m_queue.WaitStop()等待所有任务处理完毕。 - 设置
m_running = false,让所有线程退出循环。 join所有工作线程,确保它们安全退出。
4.6 WorkStealingThreadPool.hpp
#include "Logger.hpp"
#include "SyncQueue_2.hpp"
#include <functional>
#include <future>
#include <memory>
#include <deque>
#include <atomic>
using namespace std;
//防止头文件被重复包含(重复编译),避免程序报错、崩溃!
#ifndef WORKSTEALINGPOOL_HPP
#define WORKSTEALINGPOOL_HPP
namespace tulun
{
class WorkStealingThreadPool
{
public:
using Task = std::function<void(void)>;// 任务类型:无参无返回值函数
private:
size_t m_numThreads;// 线程数量
tulun::WSyncQueue<Task> m_queue;// 多桶任务队列(N个线程对应N个桶)
std::vector<std::shared_ptr<std::thread>> m_threadgroup;// 线程组
std::atomic<bool> m_running;
std::once_flag m_flag;// 保证Stop只调用一次
size_t threadIndex()
{
static size_t num = 0;
return num++ % m_numThreads;
}
void Start(int numthreads)
{
m_running = true;
for (int i = 0; i < numthreads; ++i)
{
std::shared_ptr<std::thread> tha = std::make_shared<std::thread>(
&WorkStealingThreadPool::RunInThread, this, i);
m_threadgroup.push_back(std::move(tha));
}
}
void RunInThread(const size_t index)
{
while (m_running)
{
std::deque<Task> taskque;
// 1. 优先从自己的队列获取任务
m_queue.PrintTaskInfo();
if (m_queue.Take(taskque, index) == 0)
{
for (auto &task : taskque)
{
if (task)
{
task();
}
}
}
else
{
// 2. 本地队列空了,尝试从其他线程窃取任务
size_t i = threadIndex();
if (i != index && m_queue.Take(taskque, i) == 0)
{
for (auto &task : taskque)
{
if (task)
{
task();
}
}
}
}
}
}
void StopThreadGroup()
{
m_queue.WaitStop();
m_running = false;
//遍历所有线程 → 等待它们全部执行完 → 再关闭线程池
for (auto &tha:m_threadgroup)
{
//auto &tha:代表每一个线程对象,tha 是 shared_ptr
//一个线程只有 正在运行 / 已结束但未 join 才是 joinable
if (tha && tha->joinable())
{
tha->join();//阻塞等待这个线程执行完毕!
}
}
m_threadgroup.clear();
}
public:
WorkStealingThreadPool(const size_t qusize = 500, const size_t numthread = 8)
: m_numThreads(numthread),
m_queue(numthread, qusize),
m_running(false)
{
Start(m_numThreads);
}
~WorkStealingThreadPool()
{
if (m_running)
{
Stop();
}
}
void Stop()
{
// std::call_once(m_flag,[this]{StopThreadGroup();});
std::call_once(m_flag, std::bind(&WorkStealingThreadPool::StopThreadGroup, this));
}
void AddTask(Task &&task)
{
if (m_queue.Put(std::move(task), threadIndex()) != 0)
{
LOG_ERROR << "Add task run task";
task();// 自己在主线程执行(降级策略)
}
}
void AddTask(const Task &task)
{
if (m_queue.Put(task, threadIndex()) != 0)
{
LOG_ERROR << "Add task run task";
task();
}
}
template <class Func, class... Args>
auto submit(Func &&func, Args &&...args)
{
using RetType = decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...));
auto task = std::make_shared<std::packaged_task<RetType()>>(
[func = std::forward<Func>(func), ... args = std::forward<Args>(args)]() mutable
{
return std::invoke(func, args...);
});
std::future<RetType> result = task->get_future();
if (m_queue.Put([task]() -> void
{ (*task)(); }, threadIndex()) != 0)
{
LOG_ERROR << "Add task run task";
(*task)();
}
return result;
}
};
} // namespace tulun
#endif
五、性能测试与适用场景
5.1 性能测试
test_funb 函数通过随机数生成、排序和文件保存三个阶段,对比了工作窃取线程池的性能:
#include <random>
#include <iostream>
using namespace std;
#include "Timestamp.hpp"
#include "Logger.hpp"
#include "WorkStealingThreadPool.hpp"
static const size_t row = 10000;
static const size_t col = 100000;
std::vector<std::vector<int>> iveca, ivecb, ivecc, ivecd;
void RandInit(std::vector<int> &ivec)
{
std::random_device rd; // 真随机种子源
std::mt19937 gen(rd()); // 随机数引擎
std::uniform_int_distribution<int> dis(0, 100000); // 分布范围 [1, 100]
// int random_num = dis(gen);
// ivec.reserve(col);
for (int i = 0; i < col; ++i)
{
// ivec.push_back(rand() % 10000);
ivec.push_back(dis(gen));
}
}
void SortVec(std::vector<int> &ivec)
{
std::sort(ivec.begin(), ivec.end());
}
void SaveFile(std::vector<int> &ivec, const std::string &filename)
{
FILE *pf = fopen(filename.c_str(), "a");
if (nullptr == pf)
{
LOG_FATAL << "fopen fail \n";
exit(EXIT_FAILURE);
}
for (int i = 0; i < col; ++i)
{
fprintf(pf, "%d ", ivec[i]);
}
fprintf(pf, "\n---------------------\n");
fclose(pf);
pf = nullptr;
}
void test_funb()
{
// 1. 随机数生成(CPU密集型)
// 2. 排序(CPU密集型)
// 3. 文件保存(I/O密集型)
srand(time(nullptr));
cout << "8 线程测试...." << endl;
tulun::Timestamp start, end;
start = tulun::Timestamp::Now();
{
ivecb.resize(row);
tulun::WorkStealingThreadPool mypool(500,8);
for (int i = 0; i < row; ++i)
{
ivecb[i].reserve(col);
// RandInit(ivecb[i]);
mypool.AddTask(std::bind(RandInit, std::ref(ivecb[i])));
}
}
end = tulun::Timestamp::Now();
cout << "RandInit: " << (tulun::diffMicro(end, start)) / 1000000 << "." << (tulun::diffMicro(end, start)) % 1000000 << endl;
start = tulun::Timestamp::Now();
{
tulun::WorkStealingThreadPool mypool(500,8);
for (int i = 0; i < row; ++i)
{
// SortVec(ivecb[i]);
mypool.AddTask(std::bind(SortVec, std::ref(ivecb[i])));
}
}
end = tulun::Timestamp::Now();
cout << "SortVec : " << (tulun::diffMicro(end, start)) / 1000000 << "." << (tulun::diffMicro(end, start)) % 1000000 << endl;
start = tulun::Timestamp::Now();
{
tulun::WorkStealingThreadPool mypool(500,8);
for (int i = 0; i < row; ++i)
{
// SaveFile(ivecb[i], "iveca.txt");
mypool.AddTask(std::bind(SaveFile, std::ref(ivecb[i]), std::string("ivecb.txt")));
}
}
end = tulun::Timestamp::Now();
cout << "SaveFile: " << (tulun::diffMicro(end, start)) / 1000000 << "." << (tulun::diffMicro(end, start)) % 1000000 << endl;
return;
}
int main()
{
test_funb();
return 0;
}
工作窃取线程池在 CPU 密集型任务中表现尤为出色,因为它能有效平衡负载,避免线程空转。
5.2 适用场景
结合工作窃取算法的特性,它最适合以下场景:
- 任务分解型应用:如并行图像处理、数据处理、排序算法,一个大任务可分解为多个独立子任务。
- 递归型任务:如斐波那契计算、归并排序,任务执行过程中会动态生成新的子任务。
- 高吞吐量 CPU 密集型任务:任务执行时间相近、无 I/O 阻塞,需要充分利用多核 CPU。
5.3 注意事项
- 工作窃取算法对 I/O 阻塞型任务 效果不佳,因为线程被阻塞时,其他线程无法窃取其任务。
- 任务量过小时,线程窃取的开销可能大于收益,需根据业务场景选择合适的线程池类型。
六、总结:工作窃取线程池的优势与核心要点
| 特性 | 说明 |
|---|---|
| 核心优势 | 减少锁竞争,实现负载均衡,提高多核利用率 |
| 关键设计 | 每个线程独立队列,优先本地执行,空闲时主动窃取 |
| 适用场景 | CPU 密集型、可分解、高并发任务 |
| 实现要点 | 多桶同步队列、批量窃取任务、优雅关闭机制 |
更多推荐
所有评论(0)