前言

在高并发多核时代,传统的固定 / 缓存线程池在负载不均时会出现大量线程空转、CPU 利用率低的问题。而工作窃取(Work Stealing)算法正是解决这一痛点的银弹。


一、为什么需要工作窃取线程池?

传统线程池通常使用一个全局任务队列,所有线程共享访问,这带来了两个核心问题:

  1. 激烈的锁竞争:多个线程同时从同一个队列取任务,频繁加锁 / 解锁导致性能损耗。
  2. 负载不均衡:如果任务处理时间差异很大,会出现部分线程忙死、部分线程闲死的情况,多核资源严重浪费。

工作窃取线程池的核心思想是:

  • 每个线程拥有独立的任务队列,优先处理自己队列里的任务,减少锁竞争。
  • 空闲线程主动 “偷取” 繁忙线程的任务,通过从其他线程的队列尾部窃取任务,实现负载均衡。

这种设计可以显著提高线程利用率,减少等待时间,是现代多核并行计算的首选方案。


二、核心设计与架构概览

核心架构如下:

  1. 多桶任务队列(WSyncQueue:为每个工作线程维护一个独立的双端队列(deque)。
  2. 工作窃取逻辑:线程优先从自己的队列取任务,队列为空时,随机选择其他线程的队列尾部窃取任务。
  3. 线程安全控制:通过互斥锁和条件变量保证多线程下的安全访问。
  4. 优雅关闭机制:确保线程池停止时,所有任务都能执行完毕。

三、关键组件:多桶同步队列 WSyncQueue

WSyncQueue 是整个线程池的基石,它维护了多个独立的任务队列,支持任务的提交、获取和窃取。

3.1 核心数据结构

#include <vector>
#include <deque>
#include <mutex>
#include <condition_variable>
using namespace std;

#ifndef SYNCQUEUE_2_HPP
#define SYNCQUEUE_2_HPP

namespace tulun
{
    template <class T>
    class WSyncQueue
    {
    private:
        std::vector<std::deque<T>> m_taskQueues;// 每个线程一个任务队列
        size_t m_bucketSize; // 8;队列数量(等于线程数)
        size_t m_maxSize;    // // 单个队列的最大容量
        mutable std::mutex m_mutex;// 全局互斥锁
        std::condition_variable m_notEmpty;// 队列非空条件变量
        std::condition_variable m_notFull;// 队列非满条件变量
        size_t m_waitTime;
        bool m_needStop; // 停止标志

    private:
        bool IsFull(const size_t index) const
        {
            bool full = m_taskQueues[index].size() >= m_maxSize;
            return full;
        }
        bool IsEmpty(const size_t index) const
        {
            return m_taskQueues[index].empty();
        }

        template <class F>
        int Add(F &&task, const size_t index)
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            bool waitret = m_notFull.wait_for(
                locker,
                std::chrono::milliseconds(m_waitTime),
                [this, index]
                { return m_needStop || !IsFull(index); });

            if (m_needStop)
            {
                return 1;
            }
            if (!waitret)//(超时/满了)
            {
                return 2;
            }

            m_taskQueues[index].push_back(std::forward<F>(task));
            m_notEmpty.notify_all();
            return 0;
        }
        size_t GetTotalTaskSize() const
        {
            size_t total = 0;
            for (const auto &que : m_taskQueues)
            {
                total += que.size();
            }
            return total;
        }

    public:
        WSyncQueue(int bucketsize = 8, int maxquesize = 500, size_t timeout = 1)
            : m_bucketSize(bucketsize),
              m_maxSize(maxquesize),
              m_waitTime(timeout),
              m_needStop(false)
        {
            // 初始化指定数量的任务队列
            ////开空间+创建对象(允许下标访问)
            m_taskQueues.resize(m_bucketSize);
        }
        ~WSyncQueue()
        {
            WaitStop();
        }
        int Put(const T &task, const size_t index)
        {
            return Add(task, index);
        }
        int Put(T &&task, const size_t index)
        {
            return Add(std::forward<T>(task), index);
        }
        int Take(std::deque<T> &tque, const size_t index)
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            bool waitret = m_notEmpty.wait_for(
                locker,
                std::chrono::milliseconds(m_waitTime),
                [this, index]() -> bool
                {
                    return m_needStop || !IsEmpty(index);
                });
            if (m_needStop)
            {
                return 1;
            }
            if (!waitret)//(超时/空)
            {
                return 2;
            }
            //移动资源,对象还在
            tque = std::move(m_taskQueues[index]);
            m_notFull.notify_all();
            return 0;
        }
        int Take(T &task, const size_t index)
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            // while (!m_needStop && IsEmpty())
            // {
            //     m_notEmpty.wait(locker);
            // }
            bool waitret = m_notEmpty.wait_for(
                locker,
                std::chrono::milliseconds(m_waitTime),
                [this, index]() -> bool
                {
                    return m_needStop || !IsEmpty(index);
                });
            if (m_needStop)
            {
                return 1;
            }
            if (!waitret)
            {
                return 2;
            }
            task = m_taskQueues[index].front();
            m_taskQueues[index].pop_front();
            m_notFull.notify_all();
            return 0;
        }

        void WaitStop()
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            while (GetTotalTaskSize() != 0)
            {
                m_notFull.wait_for(locker, std::chrono::seconds(m_waitTime));
            }
            // 设置停止标志
            m_needStop = true;
            // 唤醒所有等待的线程
            m_notEmpty.notify_all();
            m_notFull.notify_all();
        }
        size_t TotalTaskSize() const
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            size_t total = 0;
            for (const auto &que : m_taskQueues)
            {
                total += que.size();
            }
            return total;
        }
        void PrintTaskInfo() const
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            for (size_t i = 0; i < m_taskQueues.size(); ++i)
            {
                printf("buck: %zu => %zu count \n", i, m_taskQueues[i].size());
            }
        }
    };
} // namespace tulun

#endif

3.2 核心方法解析

任务提交(Add/Put
  • 提交任务时,线程根据自身 ID 选择一个专属队列,将任务加入尾部。
  • 如果队列已满,则等待或直接在当前线程执行任务(降级策略)。
template <class F>
        int Add(F &&task, const size_t index)
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            bool waitret = m_notFull.wait_for(
                locker,
                std::chrono::milliseconds(m_waitTime),
                [this, index]
                { return m_needStop || !IsFull(index); });

            if (m_needStop)
            {
                return 1;
            }
            if (!waitret)//(超时/满了)
            {
                return 2;
            }

            m_taskQueues[index].push_back(std::forward<F>(task));
            m_notEmpty.notify_all();
            return 0;
        }
任务获取(Take
  • 本地获取:线程优先从自己队列的头部获取任务。
  • 窃取任务:当本地队列为空时,从其他线程队列的尾部批量窃取任务,减少锁竞争次数。
int Take(std::deque<T> &tque, const size_t index)
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            bool waitret = m_notEmpty.wait_for(
                locker,
                std::chrono::milliseconds(m_waitTime),
                [this, index]() -> bool
                {
                    return m_needStop || !IsEmpty(index);
                });
            if (m_needStop)
            {
                return 1;
            }
            if (!waitret)//(超时/空)
            {
                return 2;
            }
            //移动资源,对象还在
            tque = std::move(m_taskQueues[index]);
            m_notFull.notify_all();
            return 0;
        }
        int Take(T &task, const size_t index)
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            // while (!m_needStop && IsEmpty())
            // {
            //     m_notEmpty.wait(locker);
            // }
            bool waitret = m_notEmpty.wait_for(
                locker,
                std::chrono::milliseconds(m_waitTime),
                [this, index]() -> bool
                {
                    return m_needStop || !IsEmpty(index);
                });
            if (m_needStop)
            {
                return 1;
            }
            if (!waitret)
            {
                return 2;
            }
            task = m_taskQueues[index].front();
            m_taskQueues[index].pop_front();
            m_notFull.notify_all();
            return 0;
        }
优雅关闭(WaitStop
  • 等待所有任务处理完毕后,设置停止标志,唤醒所有等待线程。
void WaitStop()
        {
            std::unique_lock<std::mutex> locker(m_mutex);
            while (GetTotalTaskSize() != 0)
            {
                m_notFull.wait_for(locker, std::chrono::seconds(m_waitTime));
            }
            // 设置停止标志
            m_needStop = true;
            // 唤醒所有等待的线程
            m_notEmpty.notify_all();
            m_notFull.notify_all();
        }

四、工作窃取线程池 WorkStealingThreadPool

这是线程池的核心,它负责管理工作线程、调度任务和实现窃取逻辑。

4.1 核心成员变量

class WorkStealingThreadPool
{
public:
    using Task = std::function<void(void)>;
private:
    size_t m_numThreads;                           // 线程数量
    tulun::WSyncQueue<Task> m_queue;               // 多桶任务队列
    std::vector<std::shared_ptr<std::thread>> m_threadgroup; // 线程组
    std::atomic<bool> m_running;                   // 运行标志
    std::once_flag m_flag;                         // 保证Stop只调用一次
    // ...
};

4.2 线程启动与任务分配(Start

  • 初始化时创建指定数量的工作线程,并为每个线程分配一个独立的任务队列。

4.3 核心窃取逻辑(RunInThread

这是工作窃取的灵魂,每个线程都在这个循环中执行任务:

void RunInThread(const size_t index)
        {
            while (m_running)
            {
                std::deque<Task> taskque;
                // 1. 优先从自己的队列获取任务
                m_queue.PrintTaskInfo();
                if (m_queue.Take(taskque, index) == 0)
                {
                    for (auto &task : taskque)
                    {
                        if (task)
                        {
                            task();
                        }
                    }
                }
                else
                {
                    // 2. 本地队列空了,尝试从其他线程窃取任务
                    size_t i = threadIndex();
                    if (i != index && m_queue.Take(taskque, i) == 0)
                    {
                        for (auto &task : taskque)
                        {
                            if (task)
                            {
                                task();
                            }
                        }
                    }
                }
            }
        }
  • 本地执行:优先处理自己队列里的任务,无需竞争。
  • 窃取执行:本地队列为空时,随机选择其他线程的队列进行窃取,实现负载均衡。

4.4 任务提交接口

支持两种任务提交方式:

  1. 普通任务(无返回值)
    void AddTask(Task &&task)
            {
                if (m_queue.Put(std::move(task), threadIndex()) != 0)
                {
                    LOG_ERROR << "Add task run task";
                    task();// 自己在主线程执行(降级策略)
                }
            }
    void AddTask(const Task &task)
            {
                if (m_queue.Put(task, threadIndex()) != 0)
                {
                    LOG_ERROR << "Add task run task";
                    task();
                }
            }
  2. 异步任务(带返回值)利用 std::packaged_taskstd::future 实现异步结果获取:
    template <class Func, class... Args>
    auto submit(Func &&func, Args &&...args)
            {
                using RetType = decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...));
                auto task = std::make_shared<std::packaged_task<RetType()>>(
                    [func = std::forward<Func>(func), ... args = std::forward<Args>(args)]() mutable
                    {
                        return std::invoke(func, args...);
                    });
    
                std::future<RetType> result = task->get_future();
    
                if (m_queue.Put([task]() -> void
                                { (*task)(); }, threadIndex()) != 0)
                {
                    LOG_ERROR << "Add task run task";
                    (*task)();
                }
    
                return result;
            }

4.5 优雅关闭(StopThreadGroup

  • 调用 m_queue.WaitStop() 等待所有任务处理完毕。
  • 设置 m_running = false,让所有线程退出循环。
  • join 所有工作线程,确保它们安全退出。

4.6 WorkStealingThreadPool.hpp


#include "Logger.hpp"
#include "SyncQueue_2.hpp"

#include <functional>
#include <future>
#include <memory>
#include <deque>
#include <atomic>
using namespace std;

//防止头文件被重复包含(重复编译),避免程序报错、崩溃!
#ifndef WORKSTEALINGPOOL_HPP
#define WORKSTEALINGPOOL_HPP

namespace tulun
{
    class WorkStealingThreadPool
    {
    public:
        using Task = std::function<void(void)>;// 任务类型:无参无返回值函数

    private:
        size_t m_numThreads;// 线程数量
        tulun::WSyncQueue<Task> m_queue;// 多桶任务队列(N个线程对应N个桶)
        std::vector<std::shared_ptr<std::thread>> m_threadgroup;// 线程组
        std::atomic<bool> m_running;
        std::once_flag m_flag;// 保证Stop只调用一次

        size_t threadIndex()
        {
            static size_t num = 0;
            return num++ % m_numThreads;
        }
        void Start(int numthreads)
        {
            m_running = true;
            for (int i = 0; i < numthreads; ++i)
            {
                std::shared_ptr<std::thread> tha = std::make_shared<std::thread>(
                    &WorkStealingThreadPool::RunInThread, this, i);
                m_threadgroup.push_back(std::move(tha));
            }
        }

        void RunInThread(const size_t index)
        {
            while (m_running)
            {
                std::deque<Task> taskque;
                // 1. 优先从自己的队列获取任务
                m_queue.PrintTaskInfo();
                if (m_queue.Take(taskque, index) == 0)
                {
                    for (auto &task : taskque)
                    {
                        if (task)
                        {
                            task();
                        }
                    }
                }
                else
                {
                    // 2. 本地队列空了,尝试从其他线程窃取任务
                    size_t i = threadIndex();
                    if (i != index && m_queue.Take(taskque, i) == 0)
                    {
                        for (auto &task : taskque)
                        {
                            if (task)
                            {
                                task();
                            }
                        }
                    }
                }
            }
        }

        void StopThreadGroup()
        {
            m_queue.WaitStop();
            m_running = false;
            //遍历所有线程 → 等待它们全部执行完 → 再关闭线程池
            for (auto &tha:m_threadgroup)
            {
                //auto &tha:代表每一个线程对象,tha 是 shared_ptr
                //一个线程只有 正在运行 / 已结束但未 join 才是 joinable
                if (tha && tha->joinable())
                {
                    tha->join();//阻塞等待这个线程执行完毕!
                }
            }
            m_threadgroup.clear();
        }

    public:
        WorkStealingThreadPool(const size_t qusize = 500, const size_t numthread = 8)
            : m_numThreads(numthread),
              m_queue(numthread, qusize),
              m_running(false)
        {
            Start(m_numThreads);
        }
        ~WorkStealingThreadPool()
        {
            if (m_running)
            {
                Stop();
            }
        }

        void Stop()
        {
            // std::call_once(m_flag,[this]{StopThreadGroup();});

            std::call_once(m_flag, std::bind(&WorkStealingThreadPool::StopThreadGroup, this));
        }

        void AddTask(Task &&task)
        {
            if (m_queue.Put(std::move(task), threadIndex()) != 0)
            {
                LOG_ERROR << "Add task run task";
                task();// 自己在主线程执行(降级策略)
            }
        }
        void AddTask(const Task &task)
        {
            if (m_queue.Put(task, threadIndex()) != 0)
            {
                LOG_ERROR << "Add task run task";
                task();
            }
        }

        template <class Func, class... Args>
        auto submit(Func &&func, Args &&...args)
        {
            using RetType = decltype(std::invoke(std::forward<Func>(func), std::forward<Args>(args)...));
            auto task = std::make_shared<std::packaged_task<RetType()>>(
                [func = std::forward<Func>(func), ... args = std::forward<Args>(args)]() mutable
                {
                    return std::invoke(func, args...);
                });

            std::future<RetType> result = task->get_future();

            if (m_queue.Put([task]() -> void
                            { (*task)(); }, threadIndex()) != 0)
            {
                LOG_ERROR << "Add task run task";
                (*task)();
            }

            return result;
        }
    };
} // namespace tulun

#endif

五、性能测试与适用场景

5.1 性能测试

test_funb 函数通过随机数生成、排序和文件保存三个阶段,对比了工作窃取线程池的性能:

#include <random>
#include <iostream>

using namespace std;

#include "Timestamp.hpp"
#include "Logger.hpp"
#include "WorkStealingThreadPool.hpp"

static const size_t row = 10000;
static const size_t col = 100000;

std::vector<std::vector<int>> iveca, ivecb, ivecc, ivecd;

void RandInit(std::vector<int> &ivec)
{
    std::random_device rd;                             // 真随机种子源
    std::mt19937 gen(rd());                            // 随机数引擎
    std::uniform_int_distribution<int> dis(0, 100000); // 分布范围 [1, 100]

    // int random_num = dis(gen);
    // ivec.reserve(col);
    for (int i = 0; i < col; ++i)
    {
        // ivec.push_back(rand() % 10000);
        ivec.push_back(dis(gen));
    }
}
void SortVec(std::vector<int> &ivec)
{
    std::sort(ivec.begin(), ivec.end());
}
void SaveFile(std::vector<int> &ivec, const std::string &filename)
{
    FILE *pf = fopen(filename.c_str(), "a");
    if (nullptr == pf)
    {
        LOG_FATAL << "fopen fail \n";
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < col; ++i)
    {
        fprintf(pf, "%d ", ivec[i]);
    }
    fprintf(pf, "\n---------------------\n");
    fclose(pf);
    pf = nullptr;
}

void test_funb()
{
    // 1. 随机数生成(CPU密集型)
    // 2. 排序(CPU密集型)
    // 3. 文件保存(I/O密集型)
    srand(time(nullptr));
    cout << "8 线程测试...." << endl;

    tulun::Timestamp start, end;
    start = tulun::Timestamp::Now();
    {
        ivecb.resize(row);
        tulun::WorkStealingThreadPool mypool(500,8);
        for (int i = 0; i < row; ++i)
        {
            ivecb[i].reserve(col);
            // RandInit(ivecb[i]);
            mypool.AddTask(std::bind(RandInit, std::ref(ivecb[i])));
        }
    }

    end = tulun::Timestamp::Now();
    cout << "RandInit: " << (tulun::diffMicro(end, start)) / 1000000 << "." << (tulun::diffMicro(end, start)) % 1000000 << endl;

    start = tulun::Timestamp::Now();
    {
        tulun::WorkStealingThreadPool mypool(500,8);
        for (int i = 0; i < row; ++i)
        {
            // SortVec(ivecb[i]);
            mypool.AddTask(std::bind(SortVec, std::ref(ivecb[i])));
        }
    }
    end = tulun::Timestamp::Now();
    cout << "SortVec : " << (tulun::diffMicro(end, start)) / 1000000 << "." << (tulun::diffMicro(end, start)) % 1000000 << endl;

    start = tulun::Timestamp::Now();
    {
        tulun::WorkStealingThreadPool mypool(500,8);
        for (int i = 0; i < row; ++i)
        {
            // SaveFile(ivecb[i], "iveca.txt");
            mypool.AddTask(std::bind(SaveFile, std::ref(ivecb[i]), std::string("ivecb.txt")));
        }
    }

    end = tulun::Timestamp::Now();
    cout << "SaveFile: " << (tulun::diffMicro(end, start)) / 1000000 << "." << (tulun::diffMicro(end, start)) % 1000000 << endl;
    return;
}

int main()
{
    test_funb();
    return 0;
}

工作窃取线程池在 CPU 密集型任务中表现尤为出色,因为它能有效平衡负载,避免线程空转。

5.2 适用场景

结合工作窃取算法的特性,它最适合以下场景:

  1. 任务分解型应用:如并行图像处理、数据处理、排序算法,一个大任务可分解为多个独立子任务。
  2. 递归型任务:如斐波那契计算、归并排序,任务执行过程中会动态生成新的子任务。
  3. 高吞吐量 CPU 密集型任务:任务执行时间相近、无 I/O 阻塞,需要充分利用多核 CPU。

5.3 注意事项

  • 工作窃取算法对 I/O 阻塞型任务 效果不佳,因为线程被阻塞时,其他线程无法窃取其任务。
  • 任务量过小时,线程窃取的开销可能大于收益,需根据业务场景选择合适的线程池类型。

六、总结:工作窃取线程池的优势与核心要点

特性 说明
核心优势 减少锁竞争,实现负载均衡,提高多核利用率
关键设计 每个线程独立队列,优先本地执行,空闲时主动窃取
适用场景 CPU 密集型、可分解、高并发任务
实现要点 多桶同步队列、批量窃取任务、优雅关闭机制

更多推荐