整体框架分析

在这里插入图片描述

  • 对SGI-STL30版本源代码的框架进行分析,可以看到rb_tree⽤了⼀个巧妙的泛型思想实现,rb_tree是实现key的搜索场景,还是key/value的搜索场景不是直接写死的,而是由第二个模板参数Value决定红黑树结点中存储的真实的数据的类型。
  • 那么,set实例化rb_tree时第二个模板参数给的是key,map实例化rb_tree时第二个模板参数给的是pair<const key,T>,这样⼀颗红黑树既可以实现key搜索场景的set,也可以实现key/value搜索场景的map。
  • rb_tree第二个模板参数Value已经控制了红黑树结点中存储的数据类型,为什么还要传第⼀个模板参数Key呢?对于map和set,find/erase时的函数参数都是Key,所以第⼀个模板参数是传给find/erase等函数做形参的类型的。对于set而言两个参数是⼀样的,但是对于map而言就完全不⼀样了,map insert的是pair对象,但是find和erase的是Key对象。

模拟实现map和set思路

实现复用红黑树的框架

  • 参考源码框架,map和set复用之前我实现的红黑树。
  • 这里相比源码调整⼀下,key参数就用K,value参数就用V,红黑树中的数据类型,我们使用T。
  • 在insert内部进行插入逻辑比较时,因为不知道比较的是K,还是pair<K,V>,且pair的默认支持的是key和value⼀起参与比较,这和我们需要的任何时候只比较key的需求不符,所以参考源码中的思路,我们在map和set层分别实现⼀个MapKeyOfT和SetKeyOfT的仿函数传给RBTree的KeyOfT,然后RBTree中通过KeyOfT仿函数取出T类型对象中的key,再进行比较。

下面通过代码展示实现细节
myset.h

#include"RBTree.h"
namespace highcool
{
	template<class K>
	class set
	{
		struct SetKeyOfT
		{
			const K& operator()(const K& key)
			{
				return key;
			}
		};
	public:
		bool insert(const K& key)
		{
			return _t.Insert(key);
		}
	private:
		RBTree<K, K, SetKeyOfT> _t;
	};
}

mymap.h

#include"RBTree.h"
namespace highcool
{
	template<class K, class V>
	class map
	{
		struct MapKeyOfT
		{
			const K& operator()(const pair<K, V>& kv)
			{
				return kv.first;
			}
		};
	public:
		bool insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv);
		}
	private:
		RBTree<K, pair<K, V>, MapKeyOfT> _t;
	};
}

RBTree.h

#pragma once
namespace highcool
{
	enum Colour
	{
		RED,
		BLACK
	};

	template<class T>
	struct RBTreeNode
	{
		// 这里更新控制平衡也要加入parent指针
		T _date;
		RBTreeNode<T>* _left;
		RBTreeNode<T>* _right;
		RBTreeNode<T>* _parent;
		Colour _col;

		RBTreeNode(const T& date)
			:_date(date)
			, _left(nullptr)
			, _right(nullptr)
			, _parent(nullptr)
		{
		}
	};

	template<class K, class T, class KeyOfT>
	class RBTree
	{
		typedef RBTreeNode<T> Node;
	public:
		bool Insert(const T& date)
		{
			if (_root == nullptr)
			{
				_root = new Node(date);
				_root->_col = BLACK;

				return true;
			}
			KeyOfT kot;
			Node* parent = nullptr;
			Node* cur = _root;
			while (cur)
			{
				if (kot(cur->_date)<kot(date))
				{
					parent = cur;
					cur = cur->_right;
				}
				else if (kot(cur->_date) > kot(date))
				{
					parent = cur;
					cur = cur->_left;
				}
				else
				{
					return false;
				}
			}

			cur = new Node(date);
			cur->_col = RED;
			if (kot(parent->_date) < kot(date))
			{
				parent->_right = cur;
			}
			else
			{
				parent->_left = cur;
			}
			// 链接父亲
			cur->_parent = parent;

			// 父亲是红色,出现连续的红色节点,需要处理
			while (parent && parent->_col == RED)
			{
				Node* grandfather = parent->_parent;
				if (parent == grandfather->_left)
				{
					//   g
					// p   u
					Node* uncle = grandfather->_right;
					if (uncle && uncle->_col == RED)
					{
						// 变色
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						// 继续往上处理
						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						if (cur == parent->_left)
						{
							//     g
							//   p    u
							// c
							RotateR(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							//      g
							//   p    u
							//     c
							RotateL(parent);
							RotateR(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
				else
				{
					//   g
					// u   p
					Node* uncle = grandfather->_left;
					// 叔叔存在且为红,-》变色即可
					if (uncle && uncle->_col == RED)
					{
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						// 继续往上处理
						cur = grandfather;
						parent = cur->_parent;
					}
					else // 叔叔不存在,或者存在且为黑
					{
						// 情况二:叔叔不存在或者存在且为黑
						// 旋转+变色
						//   g
						// u   p
						//       c
						if (cur == parent->_right)
						{
							RotateL(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							RotateR(parent);
							RotateL(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
			}

			_root->_col = BLACK;

			return true;
		}

		void RotateR(Node* parent)
		{
			Node* subL = parent->_left;
			Node* subLR = subL->_right;

			parent->_left = subLR;
			if (subLR)
				subLR->_parent = parent;

			Node* pParent = parent->_parent;

			subL->_right = parent;
			parent->_parent = subL;

			if (parent == _root)
			{
				_root = subL;
				subL->_parent = nullptr;
			}
			else
			{
				if (pParent->_left == parent)
				{
					pParent->_left = subL;
				}
				else
				{
					pParent->_right = subL;
				}

				subL->_parent = pParent;
			}
		}

		void RotateL(Node* parent)
		{
			Node* subR = parent->_right;
			Node* subRL = subR->_left;
			parent->_right = subRL;
			if (subRL)
				subRL->_parent = parent;

			Node* parentParent = parent->_parent;
			subR->_left = parent;
			parent->_parent = subR;
			if (parentParent == nullptr)
			{
				_root = subR;
				subR->_parent = nullptr;
			}
			else
			{
				if (parent == parentParent->_left)
				{
					parentParent->_left = subR;
				}
				else
				{
					parentParent->_right = subR;
				}
				subR->_parent = parentParent;
			}
		}

		void InOrder()
		{
			_InOrder(_root);
			cout << endl;
		}

		int Height()
		{
			return _Height(_root);
		}

		int Size()
		{
			return _Size(_root);
		}

		Node* Find(const K& key)
		{
			Node* cur = _root;
			while (cur)
			{
				if (cur->_kv.first < key)
				{
					cur = cur->_right;
				}
				else if (cur->_kv.first > key)
				{
					cur = cur->_left;
				}
				else
				{
					return cur;
				}
			}

			return nullptr;
		}

		bool IsBalance()
		{
			if (_root == nullptr)
				return true;

			if (_root->_col == RED)
				return false;

			// 参考值
			int refNum = 0;
			Node* cur = _root;
			while (cur)
			{
				if (cur->_col == BLACK)
				{
					++refNum;
				}
				cur = cur->_left;
			}

			return Check(_root, 0, refNum);
		}

	private:

		bool Check(Node* root, int blackNum, const int refNum)
		{
			if (root == nullptr)
			{
				// 前序遍历走到空时,意味着一条路径走完了
				//cout << blackNum << endl;
				if (refNum != blackNum)
				{
					cout << "存在黑色结点的数量不相等的路径" << endl;
					return false;
				}
				return true;
			}

			// 检查孩子不太方便,因为孩子有两个,且不一定存在,反过来检查父亲就方便多了
			if (root->_col == RED && root->_parent->_col == RED)
			{
				cout << root->_kv.first << "存在连续的红色结点" << endl;
				return false;
			}

			if (root->_col == BLACK)
			{
				blackNum++;
			}

			return Check(root->_left, blackNum, refNum)
				&& Check(root->_right, blackNum, refNum);
		}

		void _InOrder(Node* root)
		{
			if (root == nullptr)
			{
				return;
			}

			_InOrder(root->_left);
			cout << root->_kv.first << ":" << root->_kv.second << endl;
			_InOrder(root->_right);
		}

		int _Height(Node* root)
		{
			if (root == nullptr)
				return 0;
			int leftHeight = _Height(root->_left);
			int rightHeight = _Height(root->_right);
			return leftHeight > rightHeight ? leftHeight + 1 : rightHeight + 1;
		}

		int _Size(Node* root)
		{
			if (root == nullptr)
				return 0;

			return _Size(root->_left) + _Size(root->_right) + 1;
		}

	private:
		Node* _root = nullptr;
	};
}

实现支持iterator

  • iterator实现的大框架跟list的iterator思路是⼀致的,用⼀个类型封装结点的指针,再通过重载运算符,实现让迭代器像指针⼀样访问的行为。

  • begin()会返回中序第⼀个结点的iterator迭代器,因为map和set的迭代器走的是中序遍历,左子树->根结点->右子树

  • 这里的难点是operator++和operator–的实现。迭代器++的核心逻辑就是不看全局,只看局部,只考虑当前中序局部要访问的下⼀个结点。
    1.迭代器++时,如果it指向的结点的右子树不为空,代表当前结点已经访问完了,要访问下⼀个结点是右子树的中序第⼀个,即右子树的最左结点
    2.迭代器++时,如果it指向的结点的右子树空,代表当前结点已经访问完了且当前结点所在的子树也访问完了,要访问的下⼀个结点在当前结点的祖先里面,所以要沿着当前结点到根的祖先路径向上找
    (1)如果当前结点是父亲的左,那么下⼀个访问的结点就是当前结点的父亲
    (2)如果当前结点是父亲的右,当前当前结点所在的子树访问完了,当前结点所在父亲的子树也访问完了,那么下⼀个访问的需要继续往根的祖先中去找,直到找到孩子是父亲左的那个父亲节点,就是中序要问题的下⼀个结点

  • end()如何表示呢?根没有父亲,没有找到孩子是父亲左的那个祖先,已经遍历到根了,父亲为空,那我们就把it中的结点指针置为nullptr,我们用nullptr去充当end。

  • 注意–end()判断到结点时空,需要特殊处理⼀下,让迭代器结点指向最右结点。

  • 迭代器–的实现跟++的思路完全类似,逻辑正好反过来即可,不赘述了。

  • 还有一些细节要点,set的iterator不支持修改,我们把set的第二个模板参数改成const K即可,RBTree <K,const K, SetKeyOfT> _t;map的iterator不支持修改key但是可以修改value,我们把map的第二个模板参数pair的第⼀个参数改成const K即可, RBTree<K, pair<const K, V>, MapKeyOfT> _t。

实现map支持[]

map要支持[]主要需要修改insert返回值支持,修改RBtree中的insert返回值为
pair<Iterator, bool> Insert(const T& data)

highcool::map和highcool::set代码实现

  • myset.h
#pragma once
#include"RBTree.h"
namespace highcool
{
	template<class K>
	class set
	{
		struct SetKeyOfT
		{
			const K& operator()(const K& key)
			{
				return key;
			}
		};
	public:
		typedef typename RBTree<K,const K,SetKeyOfT>::Iterator iterator;
		typedef typename RBTree<K, const K, SetKeyOfT>::ConstIterator const_iterator;
		iterator begin()
		{
			return _t.Begin();
		}
		iterator end()
		{
			return _t.End();
		}
		const_iterator begin() const
		{
			return _t.Begin();
		}
		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator,bool> insert(const K& key)
		{
			return _t.Insert(key);
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}
	private:
		RBTree<K, const K, SetKeyOfT> _t;
	};
}
  • mymap.h
#pragma once
#include"RBTree.h"
namespace highcool
{
	template<class K, class V>
	class map
	{
		struct MapKeyOfT
		{
			const K& operator()(const pair<K, V>& kv)
			{
				return kv.first;
			}
		};
	public:
		typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::Iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::ConstIterator const_iterator;
		iterator begin()
		{
			return _t.Begin();
		}
		iterator end()
		{
			return _t.End();
		}
		const_iterator begin() const
		{
			return _t.Begin();
		}
		const_iterator end() const
		{
			return _t.End();
		}
		pair<iterator, bool> insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv);
		}
		iterator find(const K& key)
		{
			return _t.Find(key);
		}

		V& operator[](const K& key)
		{
			pair<iterator, bool> ret = insert({ key,V() });
			return ret.first->second;
		}
	private:
		RBTree<K, pair<const K, V>, MapKeyOfT> _t;
	};
}
  • RBTree.h
#pragma once
namespace highcool
{
	enum Colour
	{
		RED,
		BLACK
	};

	template<class T>
	struct RBTreeNode
	{
		// 这里更新控制平衡也要加入parent指针
		T _date;
		RBTreeNode<T>* _left;
		RBTreeNode<T>* _right;
		RBTreeNode<T>* _parent;
		Colour _col;

		RBTreeNode(const T& date)
			:_date(date)
			, _left(nullptr)
			, _right(nullptr)
			, _parent(nullptr)
		{
		}
	};

	template<class T,class Ref,class Ptr>
	struct RBTreeIterator
	{
		typedef RBTreeNode<T> Node;
		typedef RBTreeIterator<T, Ref, Ptr> Self;
		Node* _node;
		Node* _root;
		RBTreeIterator(Node* node,Node* root)
			:_node(node)
			,_root(root)
		{ }
		Self operator++()
		{
			if (_node->_right)
			{
				Node* cur = _node->_right;
				while (cur->_left)
				{
					cur = cur->_left;
				}
				_node = cur;
			}
			else
			{
				Node* cur = _node;
				Node* parent = cur->_parent;
				while (parent&&cur == parent->_right)
				{
					cur = parent;
					parent = cur->_parent;
				}
				_node = parent;
			}
			return *this;
		}

		Self& operator--()
		{
			if (_node == nullptr)
			{
				Node* cur = _root;
				while (cur && cur->_right)
				{
					cur = cur->_right;
				}
				_node = cur;
			}
			else if (_node->_left)
			{
				Node* cur = _node->_left;
				while (cur->_right)
				{
					cur = cur->_right;
				}
				_node = cur;
			}
			else
			{
				Node* cur = _node;
				Node* parent = cur->_parent;
				while (parent && cur == parent->_left)
				{
					cur = parent;
					parent = cur->_parent;
				}
				_node = parent;
			}
			return *this;
		}

		Ref operator*()
		{
			return _node->_date;
		}
		Ptr operator->()
		{
			return &(_node->_date);
		}
		bool operator==(const Self& s) const
		{
			return _node == s._node;
		}
		bool operator!=(const Self& s) const
		{
			return _node != s._node;
		}
	};


	template<class K, class T, class KeyOfT>
	class RBTree
	{
		typedef RBTreeNode<T> Node;
	public:
		typedef RBTreeIterator<T, T&, T*> Iterator;
		typedef RBTreeIterator<T, const T&, const T*> ConstIterator;

		Iterator Begin()
		{
			Node* cur = _root;
			while (cur && cur->_left)
			{
				cur = cur->_left;
			}
			return Iterator(cur, _root);
		}
		Iterator End()
		{
			return Iterator(nullptr, _root);
		}

		ConstIterator Begin() const 
		{
			Node* cur = _root;
			while (cur && cur->_left)
			{
				cur = cur->_left;
			}
			return ConstIterator(cur, _root);
		}
		ConstIterator End() const
		{
			return ConstIterator(nullptr, _root);
		}

		RBTree() = default;

		~RBTree()
		{
			Destroy(_root);
			_root = nullptr;
		}

		void Destroy(Node* root)
		{
			if (root == nullptr)
				return;
			Destroy(root->_left);
			Destroy(root->_right);
			delete root;
		}

		pair<Iterator,bool> Insert(const T& date)
		{
			if (_root == nullptr)
			{
				_root = new Node(date);
				_root->_col = BLACK;

				return {Iterator(_root,_root),true};
			}
			KeyOfT kot;
			Node* parent = nullptr;
			Node* cur = _root;
			while (cur)
			{
				if (kot(cur->_date)<kot(date))
				{
					parent = cur;
					cur = cur->_right;
				}
				else if (kot(cur->_date) > kot(date))
				{
					parent = cur;
					cur = cur->_left;
				}
				else
				{
					return { Iterator(cur,_root),false };
				}
			}

			cur = new Node(date);
			Node* newnode = cur;
			cur->_col = RED;
			if (kot(parent->_date) < kot(date))
			{
				parent->_right = cur;
			}
			else
			{
				parent->_left = cur;
			}
			// 链接父亲
			cur->_parent = parent;

			// 父亲是红色,出现连续的红色节点,需要处理
			while (parent && parent->_col == RED)
			{
				Node* grandfather = parent->_parent;
				if (parent == grandfather->_left)
				{
					//   g
					// p   u
					Node* uncle = grandfather->_right;
					if (uncle && uncle->_col == RED)
					{
						// 变色
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						// 继续往上处理
						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						if (cur == parent->_left)
						{
							//     g
							//   p    u
							// c
							RotateR(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							//      g
							//   p    u
							//     c
							RotateL(parent);
							RotateR(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
				else
				{
					//   g
					// u   p
					Node* uncle = grandfather->_left;
					// 叔叔存在且为红,-》变色即可
					if (uncle && uncle->_col == RED)
					{
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						// 继续往上处理
						cur = grandfather;
						parent = cur->_parent;
					}
					else // 叔叔不存在,或者存在且为黑
					{
						// 情况二:叔叔不存在或者存在且为黑
						// 旋转+变色
						//   g
						// u   p
						//       c
						if (cur == parent->_right)
						{
							RotateL(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							RotateR(parent);
							RotateL(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
			}

			_root->_col = BLACK;

			return { Iterator(newnode,_root),true };
		}

		void RotateR(Node* parent)
		{
			Node* subL = parent->_left;
			Node* subLR = subL->_right;

			parent->_left = subLR;
			if (subLR)
				subLR->_parent = parent;

			Node* pParent = parent->_parent;

			subL->_right = parent;
			parent->_parent = subL;

			if (parent == _root)
			{
				_root = subL;
				subL->_parent = nullptr;
			}
			else
			{
				if (pParent->_left == parent)
				{
					pParent->_left = subL;
				}
				else
				{
					pParent->_right = subL;
				}

				subL->_parent = pParent;
			}
		}

		void RotateL(Node* parent)
		{
			Node* subR = parent->_right;
			Node* subRL = subR->_left;
			parent->_right = subRL;
			if (subRL)
				subRL->_parent = parent;

			Node* parentParent = parent->_parent;
			subR->_left = parent;
			parent->_parent = subR;
			if (parentParent == nullptr)
			{
				_root = subR;
				subR->_parent = nullptr;
			}
			else
			{
				if (parent == parentParent->_left)
				{
					parentParent->_left = subR;
				}
				else
				{
					parentParent->_right = subR;
				}
				subR->_parent = parentParent;
			}
		}

		void InOrder()
		{
			_InOrder(_root);
			cout << endl;
		}

		int Height()
		{
			return _Height(_root);
		}

		int Size()
		{
			return _Size(_root);
		}

		Iterator Find(const K& key)
		{
			Node* cur = _root;
			while (cur)
			{
				if (cur->_kv.first < key)
				{
					cur = cur->_right;
				}
				else if (cur->_kv.first > key)
				{
					cur = cur->_left;
				}
				else
				{
					return Iterator(cur,_root);
				}
			}

			return End();
		}

		bool IsBalance()
		{
			if (_root == nullptr)
				return true;

			if (_root->_col == RED)
				return false;

			// 参考值
			int refNum = 0;
			Node* cur = _root;
			while (cur)
			{
				if (cur->_col == BLACK)
				{
					++refNum;
				}
				cur = cur->_left;
			}

			return Check(_root, 0, refNum);
		}

	private:

		bool Check(Node* root, int blackNum, const int refNum)
		{
			if (root == nullptr)
			{
				// 前序遍历走到空时,意味着一条路径走完了
				//cout << blackNum << endl;
				if (refNum != blackNum)
				{
					cout << "存在黑色结点的数量不相等的路径" << endl;
					return false;
				}
				return true;
			}

			// 检查孩子不太方便,因为孩子有两个,且不一定存在,反过来检查父亲就方便多了
			if (root->_col == RED && root->_parent->_col == RED)
			{
				cout << root->_kv.first << "存在连续的红色结点" << endl;
				return false;
			}

			if (root->_col == BLACK)
			{
				blackNum++;
			}

			return Check(root->_left, blackNum, refNum)
				&& Check(root->_right, blackNum, refNum);
		}

		void _InOrder(Node* root)
		{
			if (root == nullptr)
			{
				return;
			}

			_InOrder(root->_left);
			cout << root->_kv.first << ":" << root->_kv.second << endl;
			_InOrder(root->_right);
		}

		int _Height(Node* root)
		{
			if (root == nullptr)
				return 0;
			int leftHeight = _Height(root->_left);
			int rightHeight = _Height(root->_right);
			return leftHeight > rightHeight ? leftHeight + 1 : rightHeight + 1;
		}

		int _Size(Node* root)
		{
			if (root == nullptr)
				return 0;

			return _Size(root->_left) + _Size(root->_right) + 1;
		}

	private:
		Node* _root = nullptr;
	};
}

更多推荐