一、多路归并算法的由来

        假定现在有一包含大量整数的文本文件存放于磁盘中,其文件大小为10GB,而本机内存只有4GB。此时若我们要对该文件中的所有整数进行升序排序,肯定不能直接将文件中的所有数据一次性读入内存中,再使用快速、归并等排序算法对这么大规模的整数进行排序。

        好像陷入了难题? 我们不妨换一个思路,为何不将10GB大文件拆分为10个1GB的小文件呢? 逐个对10个文件进行排序后,再将其写入磁盘中,此时就得到了10份已排序后的临时文件

        每一份文件都是一个升序序列,这时问题就转换为如何合并这10路升序序列为1路升序序列。正因为待合并的数据路数比较多,所以才有了多路归并这一说法。

        还是有些抽象?那不妨举个具体的例子来瞅瞅,假定有nums1、num2、nums3三路升序序列,先打算将他们合并为一路升序序列:

nums1 = [1, 2, 3]

nums2 = [2, 4, 6]

nums3 = [3, 5 ,8] =》 合并结果: res = [1, 2, 2, 3, 3, 4, 5, 6, 8]

二、多路归并算法的最简编程模型

        现在我们将刚才提到的合并k路升序序列问题,转化为具体的C++代码,也即是建立最简单的多路归并算法编程模型。以k = 3,即3路为例:

        为了便于理解,我们将三路序列存储为矩阵的形式(如上图),横坐标x的取值范围为0,1,2代表有三列,纵坐标的取值范围也为0,1,2代表有三行。第一行(y=0)是nums1,第二行(y=1)是nums2,第三行(y=2)是nums3。

多路归并算法的基本思想如下:

1) 首先建立一个小顶堆;

2) 将每一路的最小元素(即第1列元素)都加入小顶堆中,此时堆顶就是k路中全局的最小值;

3) 将堆顶元素弹出,并将堆顶元素所在数组的下一个元素加入堆中。

4) 重复第2)和第3)步,直至每一路数据都读取结束。

#include <iostream>
#include <vector>
#include <queue>

int main()
{
    using namespace std;
    using VI = vector<int>;

    vector<VI> src = { {1, 2, 3}, {2, 4, 6}, {3, 5, 8} };
    priority_queue<VI, vector<VI>, greater<VI>> pq;
    VI vc;

    //1.先将每路的最小值(第一列元素)放入堆中, 以获得全局的最小值
    for(int i = 0; i < src.size(); ++i)
    {
        vc = {src[i][0], i, 0};  //值, 纵坐标, 横坐标
        pq.emplace(vc);
    }

    VI ans;
    while( !pq.empty() )
    {
        vc = pq.top();  //弹出堆顶元素
        pq.pop();

        ans.emplace_back(vc[0]);
        //将堆顶元素所在数组的下一个元素放入堆中
        if(vc[2] < src[vc[1]].size() - 1) {
            vc = {src[ vc[1] ][ vc[2] + 1 ], vc[1], vc[2] + 1};
            pq.emplace(vc);
        }
    }

    for(auto &x : ans)
    {
        cout << x << " ";
    }

    return 0;
}

​​​​​​​

【题外话】

  Q:这里提到的多路归并算法与常说的归并排序算法有何联系?存在单路归并吗?

  A:若读者能回想下归并排序的过程,就知道归并排序实际上是两路数据进行归并,即写出递推公式:sort(nums) = merge[ sort(nums[0: mid]) + sort(nums[mid + 1, right]) ] ,这不正是对两路数据进行合并嘛!

三、实战一道LeetCode算法题

 1. LeetCode #23 合并k个升序链表 (Hard)

【题目描述】这里直接照搬了LeetCode官网的题目说明:

               

(大致想法) 

        单从题目示例给出的形式来看,是不是与我们刚才提到的多路归并算法如出一辙?只不过这里给出的不是数组,而是链表,但换汤不换药,本质思想其实还是一样的。我们套用上面的多路归并算法模型来解决下本题。

(代码实现)

class Solution {
public:
    ListNode* mergeKLists(vector<ListNode*>& lists) {
        if(lists.size() == 0) return nullptr;

        auto cmp = [] (const ListNode* node1, const ListNode* node2) -> bool{
            return node1->val > node2->val;
        };
        //小顶堆
        priority_queue<ListNode*, vector<ListNode*>, decltype(cmp)> pq(cmp);
        
        //1.将每个链表的首结点入堆, 以获取全局的最小值
        for(auto &node: lists)
        {
            if(node) pq.emplace(node);  //空链表不入堆
        }

        ListNode dummyNode, *r = &dummyNode; //r始终指向已排序部分的尾部
        ListNode *minNode;  //用于保存堆顶元素

        while(!pq.empty())
        {
            minNode = pq.top();  //弹出堆顶元素
            pq.pop();
            
            r->next = minNode;  //将堆顶元素挂载至结果集中
            r = r->next;  
            //将堆顶元素所在链表的下一个结点加入堆中
            if(minNode->next) pq.emplace(minNode->next);
        }

        return dummyNode.next;
    }
};

时间复杂度:O(knlogk),logk为堆的高度,kn为添加到堆中的总结点数。

空间复杂度: O(k),小顶堆的空间。

【题外话】

  Q1:如果将题目改为“获取k个升序链表合并后的前k个最小值”,此时又该如何解决问题?需要将k个链表的所有数据都遍历完后,才能得出这k个最小值吗?

  Q2:感兴趣的读者,可以继续带着多路归并的思想去学习LeetCode #373 查找和最小的k对数字。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐