Implementing a thread-safe map in C++
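STL containers are not thread-safe. Take vector: even with a single writer (producer), a concurrent reader (consumer) can have its iterators invalidated, because push_back() may reallocate the buffer and copy or move the elements. In practice this usually shows up as a core dump. Likewise, several writers calling push_back() concurrently is a data race and can also core dump. You can avoid locks altogether by fixing the vector's size up front (call resize()) and never calling push_back(), so the buffer is never reallocated; threads that write to distinct elements then need no synchronization.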
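The fixed-size approach can be sketched as follows. This is a minimal illustration, not code from any particular project; fill_slots is a hypothetical helper name, and the sketch assumes each thread writes only to its own index. It does not apply to std::vector<bool>, whose elements are packed bits and do not occupy distinct memory locations.

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Fills a pre-sized vector from several threads; each thread owns one slot.
// fill_slots is a hypothetical name used only for this illustration.
std::vector<int> fill_slots(std::size_t thread_count)
{
    // Size the vector once, up front. No push_back happens afterwards, so the
    // buffer is never reallocated and element addresses stay stable.
    std::vector<int> slots(thread_count, 0);

    std::vector<std::thread> workers;
    workers.reserve(thread_count);
    for (std::size_t i = 0; i < thread_count; ++i) {
        // Thread i writes only slots[i]; distinct elements of a vector<int>
        // may be modified concurrently without a data race.
        workers.emplace_back([&slots, i] { slots[i] = static_cast<int>(i) * 2; });
    }
    for (auto& w : workers) {
        w.join(); // join also makes the writes visible to the caller
    }
    return slots;
}
```

Calling fill_slots(8) starts eight threads and returns once all of them have joined. As soon as two threads may touch the same element, or any thread may grow the vector, a lock is needed again.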
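Concurrent operations on C++ maps are unsafe as well. C++ offers std::map, typically implemented as a red-black tree, and std::unordered_map, implemented as a hash table. Page 162 of "C++ Concurrency in Action" presents a map with fine-grained locking built on boost's shared_mutex (the standard library gained std::shared_timed_mutex in C++14 and std::shared_mutex in C++17; C++11 has neither). That implementation is fairly long, so here is the safe_map from the OpenHarmony source instead: the code is compact and worth studying.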
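For comparison, here is a minimal sketch of a reader/writer map built on std::shared_mutex. It is neither the book's fine-grained design nor OpenHarmony's code; the class name RwMap and its Set/Get methods are hypothetical, and the sketch assumes a C++17 compiler (with C++14, std::shared_timed_mutex could be substituted).

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <string>

// Hypothetical reader/writer map, for illustration only.
template <typename K, typename V>
class RwMap {
public:
    // Writers take the mutex exclusively.
    void Set(const K& key, const V& value)
    {
        std::unique_lock<std::shared_mutex> lock(mutex_);
        map_[key] = value;
    }

    // Readers take the mutex in shared mode, so several Get calls can
    // run at the same time as long as no writer holds the lock.
    bool Get(const K& key, V& value) const
    {
        std::shared_lock<std::shared_mutex> lock(mutex_);
        auto it = map_.find(key);
        if (it == map_.end()) {
            return false;
        }
        value = it->second; // copy out while the lock is held
        return true;
    }

private:
    mutable std::shared_mutex mutex_; // mutable so const readers can lock it
    std::map<K, V> map_;
};
```

A single shared mutex like this tends to pay off when reads greatly outnumber writes; under heavy write traffic a plain std::mutex is often just as good.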
1. Source implementation
Source location: code-v3.0-LTS\OpenHarmony\utils\native\base\include\safe_map.h
/*
* Copyright (c) 2021 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef UTILS_BASE_SAFE_MAP_H
#define UTILS_BASE_SAFE_MAP_H

#include <map>
#include <mutex>

namespace OHOS {

template <typename K, typename V>
class SafeMap {
public:
    SafeMap() {}

    ~SafeMap() {}

    SafeMap(const SafeMap& rhs)
    {
        map_ = rhs.map_;
    }

    SafeMap& operator=(const SafeMap& rhs)
    {
        if (&rhs != this) {
            map_ = rhs.map_;
        }

        return *this;
    }

    V& operator[](const K& key)
    {
        return map_[key];
    }

    // when multithread calling size() return a tmp status, some threads may insert just after size() call
    int Size()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        return map_.size();
    }

    // when multithread calling Empty() return a tmp status, some threads may insert just after Empty() call
    bool IsEmpty()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        return map_.empty();
    }

    bool Insert(const K& key, const V& value)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto ret = map_.insert(std::pair<K, V>(key, value));
        return ret.second;
    }

    void EnsureInsert(const K& key, const V& value)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto ret = map_.insert(std::pair<K, V>(key, value));
        // find key and cannot insert
        if (!ret.second) {
            map_.erase(ret.first);
            map_.insert(std::pair<K, V>(key, value));
            return;
        }
        return;
    }

    bool Find(const K& key, V& value)
    {
        bool ret = false;
        std::lock_guard<std::mutex> lock(mutex_);

        auto iter = map_.find(key);
        if (iter != map_.end()) {
            value = iter->second;
            ret = true;
        }

        return ret;
    }

    bool FindOldAndSetNew(const K& key, V& oldValue, const V& newValue)
    {
        bool ret = false;
        std::lock_guard<std::mutex> lock(mutex_);
        if (map_.size() > 0) {
            auto iter = map_.find(key);
            if (iter != map_.end()) {
                oldValue = iter->second;
                map_.erase(iter);
                map_.insert(std::pair<K, V>(key, newValue));
                ret = true;
            }
        }

        return ret;
    }

    void Erase(const K& key)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        map_.erase(key);
    }

    void Clear()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        map_.clear();
        return;
    }

private:
    std::mutex mutex_;
    std::map<K, V> map_;
};

} // namespace OHOS
#endif
2. Code walkthrough
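The template syntax template <typename K, typename V> makes this map implementation generic. This is the strength of C++ templates: there is no need to write one implementation per type, so the code is more reusable. Templates are also checked at compile time, which reduces the chance of errors. The internals are nothing special: each operation simply takes a lock. The lock is taken with std::lock_guard, the RAII idiom, which is common and widely used.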
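To make the RAII behaviour of std::lock_guard concrete, the sketch below swaps the real mutex for a hypothetical TracingLock that only counts calls. std::lock_guard accepts any type with lock() and unlock(), so the counters show that unlock() runs on every exit path, including when an exception is thrown. TracingLock and work_under_lock are illustrative names, and TracingLock provides no actual mutual exclusion.

```cpp
#include <cassert>
#include <mutex>
#include <stdexcept>

// Hypothetical stand-in for a mutex: it only records how often it is
// locked and unlocked. It does NOT provide mutual exclusion.
struct TracingLock {
    int lock_calls = 0;
    int unlock_calls = 0;
    void lock() { ++lock_calls; }
    void unlock() { ++unlock_calls; }
};

// Simulates a critical section that may fail part-way through.
void work_under_lock(TracingLock& tl, bool fail)
{
    std::lock_guard<TracingLock> guard(tl); // constructor calls tl.lock()
    if (fail) {
        // Stack unwinding destroys guard, which calls tl.unlock().
        throw std::runtime_error("simulated failure");
    }
} // normal return: guard's destructor calls tl.unlock()
```

After one successful call and one throwing call, lock_calls and unlock_calls are both 2. This is why SafeMap's methods never need an explicit unlock.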
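The class defines a handful of common operations such as Find, Erase and Clear, each of which takes the lock. It also overloads operator[] and the assignment operator =. Note that these two, along with the copy constructor, do not take the lock. Do you know why? The usual justification is that if multiple threads only read a container and never change its structure, no synchronization is needed, and that modifying an existing element through [] rather than inserting a new one leaves the tree structure untouched, so it is unlikely to core dump. That reasoning only holds under those restrictions. operator[] inserts a default-constructed value when the key is missing, and it returns a reference that the caller uses after the function has returned, so a lock taken inside it could not protect the caller's read or write anyway. Two threads writing the same value through [] are still a data race even when nothing crashes. Similarly, the copy constructor and operator= read rhs.map_ without locking rhs. Treat these three members as safe only when no other thread is modifying the map at the same time.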
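If operator[] and the copy operations need to be usable while other threads are writing, one option is sketched below. This is a hypothetical variant, not OpenHarmony's code: the name GuardedMap and the Set/Get methods are assumptions for illustration. Copying locks the source, assignment locks both objects through std::lock to avoid deadlock, and element access goes through methods that copy values in or out under the lock instead of handing out a reference.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <string>

// Hypothetical variant for illustration; names are not from OpenHarmony.
template <typename K, typename V>
class GuardedMap {
public:
    GuardedMap() = default;

    // Lock the source while its contents are copied.
    GuardedMap(const GuardedMap& rhs)
    {
        std::lock_guard<std::mutex> lock(rhs.mutex_);
        map_ = rhs.map_;
    }

    GuardedMap& operator=(const GuardedMap& rhs)
    {
        if (&rhs != this) {
            // std::lock acquires both mutexes without risking deadlock;
            // the guards then adopt ownership so both are released on exit.
            std::lock(mutex_, rhs.mutex_);
            std::lock_guard<std::mutex> self(mutex_, std::adopt_lock);
            std::lock_guard<std::mutex> other(rhs.mutex_, std::adopt_lock);
            map_ = rhs.map_;
        }
        return *this;
    }

    // Write path replacing operator[]: the assignment happens under the lock.
    void Set(const K& key, const V& value)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        map_[key] = value;
    }

    // Read path replacing operator[]: copies the value out, never a reference.
    bool Get(const K& key, V& value) const
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = map_.find(key);
        if (it == map_.end()) {
            return false;
        }
        value = it->second;
        return true;
    }

private:
    mutable std::mutex mutex_; // mutable so const members can lock it
    std::map<K, V> map_;
};
```

The trade-off is an extra copy of V on every access, which is usually acceptable for small value types.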
3. Unit tests
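The source tree also contains unit tests for safe_map, written with Google's gtest framework. gtest is evidently capable enough that Huawei chose it as well. The source is given below and is a good way to get familiar with gtest usage. Note that HWTEST_F, the testing::ext namespace and TestSize.Level1 appear to come from OpenHarmony's extension of gtest rather than from upstream gtest.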
/*
* Copyright (c) 2021 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "safe_map.h"
#include <algorithm> // std::sort
#include <array>
#include <chrono> // std::chrono::seconds
#include <future>
#include <gtest/gtest.h>
#include <iostream> // std::cout
#include <string>
#include <thread> // std::thread, std::this_thread::sleep_for
#include <vector>
using namespace testing::ext;
using namespace OHOS;
using namespace std;
class UtilsSafeMap : public testing::Test {
};
/*
* @tc.name: testUtilsCopyAndAssign001
* @tc.desc: single thread test the normal feature insert and erase and EnsureInsert
*/
HWTEST_F(UtilsSafeMap, testUtilsCopyAndAssign001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    // insert new
    demoData.Insert("A", 1);
    ASSERT_FALSE(demoData.IsEmpty());
    ASSERT_EQ(demoData.Size(), 1);

    SafeMap<string, int> newdemo = demoData;
    int tar = -1;
    ASSERT_TRUE(newdemo.Find("A", tar));
    ASSERT_EQ(1, tar);

    tar = -1;
    SafeMap<string, int> newdemo2;
    newdemo2 = demoData;
    ASSERT_TRUE(newdemo2.Find("A", tar));
    ASSERT_EQ(1, tar);
}
/*
* @tc.name: testUtilsoperator001
* @tc.desc: SafeMap
*/
HWTEST_F(UtilsSafeMap, testUtilsoperator001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    // insert new
    demoData.Insert("A", 1);
    ASSERT_FALSE(demoData.IsEmpty());
    ASSERT_EQ(demoData.Size(), 1);
    ASSERT_EQ(demoData["A"], 1);

    SafeMap<string, int> newdemo = demoData;
    ASSERT_EQ(newdemo["A"], 1);

    int tar = -1;
    newdemo["B"] = 6;
    ASSERT_TRUE(newdemo.Find("B", tar));
    ASSERT_EQ(6, tar);

    SafeMap<string, int> newdemo2;
    newdemo2 = newdemo;
    ASSERT_EQ(newdemo2["A"], 1);
}
/*
* @tc.name: testUtilsNormalFeatureInsert001
* @tc.desc: SafeMap
*/
HWTEST_F(UtilsSafeMap, testUtilsNormalFeatureInsert001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    ASSERT_TRUE(demoData.IsEmpty());

    // insert new
    demoData.Insert("A", 1);
    ASSERT_FALSE(demoData.IsEmpty());
    ASSERT_EQ(demoData.Size(), 1);

    // insert copy one should fail
    ASSERT_FALSE(demoData.Insert("A", 2));
    ASSERT_EQ(demoData.Size(), 1);
}
/*
* @tc.name: testUtilsNormalFeatureEnsureInsert001
* @tc.desc: SafeMap
*/
HWTEST_F(UtilsSafeMap, testUtilsNormalFeatureEnsureInsert001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    ASSERT_TRUE(demoData.IsEmpty());

    demoData.Insert("A", 1);
    demoData.EnsureInsert("B", 2);

    ASSERT_FALSE(demoData.IsEmpty());
    ASSERT_EQ(demoData.Size(), 2);

    // insert copy one and new one
    demoData.EnsureInsert("B", 5);
    demoData.EnsureInsert("C", 6);

    ASSERT_EQ(demoData.Size(), 3);
}
/*
* @tc.name: testUtilsNormalFeatureFind001
* @tc.desc: SafeMap
*/
HWTEST_F(UtilsSafeMap, testUtilsNormalFeatureFind001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    ASSERT_TRUE(demoData.IsEmpty());

    demoData.Insert("A", 1);
    demoData.Insert("B", 10000);
    demoData.EnsureInsert("B", 2);
    demoData.EnsureInsert("C", 6);

    ASSERT_FALSE(demoData.IsEmpty());
    ASSERT_EQ(demoData.Size(), 3);

    int i = 0;
    ASSERT_TRUE(demoData.Find("A", i));
    ASSERT_EQ(i, 1);

    ASSERT_TRUE(demoData.Find("B", i));
    ASSERT_EQ(i, 2);

    ASSERT_TRUE(demoData.Find("C", i));
    ASSERT_EQ(i, 6);
}
/*
* @tc.name: testUtilsNormalFeatureFindAndSet001
* @tc.desc: SafeMap
*/
HWTEST_F(UtilsSafeMap, testUtilsNormalFeatureFindAndSet001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    ASSERT_TRUE(demoData.IsEmpty());

    demoData.Insert("A", 1);
    demoData.EnsureInsert("B", 2);

    int oldvalue = 0;
    int newvalue = 3;
    ASSERT_TRUE(demoData.FindOldAndSetNew("A", oldvalue, newvalue));
    // old value
    ASSERT_EQ(oldvalue, 1);

    newvalue = 4;
    ASSERT_TRUE(demoData.FindOldAndSetNew("B", oldvalue, newvalue));
    // old value
    ASSERT_EQ(oldvalue, 2);

    int i = -1;
    ASSERT_TRUE(demoData.Find("A", i));
    // new value
    ASSERT_EQ(i, 3);

    ASSERT_TRUE(demoData.Find("B", i));
    // new value
    ASSERT_EQ(i, 4);
}
/*
* @tc.name: testUtilsNormalFeatureEraseAndClear001
* @tc.desc: SafeMap
*/
HWTEST_F(UtilsSafeMap, testUtilsNormalFeatureEraseAndClear001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    ASSERT_TRUE(demoData.IsEmpty());

    demoData.Insert("A", 1);
    demoData.EnsureInsert("B", 2);

    ASSERT_EQ(demoData.Size(), 2);

    demoData.Erase("A");
    ASSERT_EQ(demoData.Size(), 1);

    demoData.Clear();
    ASSERT_EQ(demoData.Size(), 0);
}
/*
* @tc.name: testUtilsConcurrentWriteAndRead001
* @tc.desc: 100 threads test in writein to the same key of the map, while read at same time and no throw
*/
const int THREAD_NUM = 100;
HWTEST_F(UtilsSafeMap, testUtilsConcurrentWriteAndRead001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    std::thread threads[THREAD_NUM];
    std::thread checkThread[THREAD_NUM];
    ASSERT_NO_THROW({
        auto lamfuncInsert = [](SafeMap<string, int>& data, const string& key,
            const int& value, const std::chrono::system_clock::time_point& absTime) {
            std::this_thread::sleep_until(absTime);
            data.EnsureInsert(key, value);
        };

        auto lamfuncCheck = [](SafeMap<string, int>& data, const string& key,
            std::chrono::system_clock::time_point absTime) {
            std::this_thread::sleep_until(absTime);
            thread_local int i = -1;
            data.Find(key, i);
        };

        using std::chrono::system_clock;

        std::time_t timeT = system_clock::to_time_t(system_clock::now());
        timeT += 2;
        string key("A");

        for (int i = 0; i < THREAD_NUM; ++i) {
            threads[i] = std::thread(lamfuncInsert, std::ref(demoData), key, i, system_clock::from_time_t(timeT));
            checkThread[i] = std::thread(lamfuncCheck, std::ref(demoData), key, system_clock::from_time_t(timeT));
        }

        std::this_thread::sleep_for(std::chrono::seconds(3));
        for (auto& t : threads) {
            t.join();
        }

        for (auto& t : checkThread) {
            t.join();
        }
    });
}
/*
* @tc.name: testUtilsConcurrentWriteAndFind001
* @tc.desc: 100 threads test in writein to the corresponding key of the map,
* while read at same time and check the results
*/
HWTEST_F(UtilsSafeMap, testUtilsConcurrentWriteAndFind001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    std::thread threads[THREAD_NUM];
    std::vector<std::future<int>> vcfi;

    ASSERT_NO_THROW({
        auto lamfuncInsert = [](SafeMap<string, int>& data, const string& key,
            const int& value, const std::chrono::system_clock::time_point& absTime) {
            std::this_thread::sleep_until(absTime);
            data.EnsureInsert(key, value);
        };

        auto lamfuncCheckLoop = [](SafeMap<string, int>& data, const string& key,
            std::chrono::system_clock::time_point absTime) {
            std::this_thread::sleep_until(absTime);
            thread_local int i = -1;
            while (!data.Find(key, i)) {
                std::this_thread::sleep_for(std::chrono::microseconds(10));
            }
            return i;
        };

        using std::chrono::system_clock;

        std::time_t timeT = system_clock::to_time_t(system_clock::now());
        timeT += 2;
        string key("A");

        for (int i = 0; i < THREAD_NUM; ++i) {
            threads[i] = std::thread(lamfuncInsert, std::ref(demoData),
                key + std::to_string(i), i, system_clock::from_time_t(timeT));
            vcfi.push_back(std::async(std::launch::async, lamfuncCheckLoop,
                std::ref(demoData), key + std::to_string(i), system_clock::from_time_t(timeT)));
        }

        std::this_thread::sleep_for(std::chrono::seconds(4));
        for (auto& t : threads) {
            t.join();
        }

        vector<int> result;
        for (auto& t : vcfi) {
            result.push_back(t.get());
        }

        std::sort(result.begin(), result.end());

        for (int i = 0; i < THREAD_NUM; ++i) {
            ASSERT_EQ(i, result[i]);
        }
    });
}
/*
* @tc.name: testUtilsConcurrentWriteAndFindAndSet001
* @tc.desc: 100 threads test in writein to the corresponding key of the map,
* while findandfix at same time and check the results
*/
HWTEST_F(UtilsSafeMap, testUtilsConcurrentWriteAndFindAndSet001, TestSize.Level1)
{
    SafeMap<string, int> demoData;
    std::thread threads[THREAD_NUM];
    std::vector<std::future<int>> vcfi;

    ASSERT_NO_THROW({
        auto lamfuncInsert = [](SafeMap<string, int>& data, const string& key,
            const int& value, const std::chrono::system_clock::time_point& absTime) {
            std::this_thread::sleep_until(absTime);
            data.EnsureInsert(key, value);
        };

        auto lamfuncCheckLoop = [](SafeMap<string, int>& data, const string& key,
            const int& newvalue, std::chrono::system_clock::time_point absTime) {
            std::this_thread::sleep_until(absTime);
            thread_local int i = -1;
            while (!data.FindOldAndSetNew(key, i, newvalue)) {
                std::this_thread::sleep_for(std::chrono::microseconds(10));
            }
            return i;
        };

        using std::chrono::system_clock;

        std::time_t timeT = system_clock::to_time_t(system_clock::now());
        timeT += 2;
        string key("A");

        for (int i = 0; i < THREAD_NUM; ++i) {
            threads[i] = std::thread(lamfuncInsert, std::ref(demoData),
                key + std::to_string(i), i, system_clock::from_time_t(timeT));
            vcfi.push_back(std::async(std::launch::async, lamfuncCheckLoop,
                std::ref(demoData), key + std::to_string(i), i + 1, system_clock::from_time_t(timeT)));
        }

        std::this_thread::sleep_for(std::chrono::seconds(4));
        for (auto& t : threads) {
            t.join();
        }

        vector<int> result;
        for (auto& t : vcfi) {
            result.push_back(t.get());
        }

        std::sort(result.begin(), result.end());

        for (int i = 0; i < THREAD_NUM; ++i) {
            ASSERT_EQ(i, result[i]);
        }

        int t = 0;
        result.clear();
        for (int i = 0; i < THREAD_NUM; ++i) {
            t = -1;
            ASSERT_TRUE(demoData.Find("A" + std::to_string(i), t));
            result.push_back(t);
        }

        std::sort(result.begin(), result.end());

        for (int i = 0; i < THREAD_NUM; ++i) {
            ASSERT_EQ(i + 1, result[i]);
        }
    });
}
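The concurrent tests above follow a simple pattern: start many writer threads, wait for them, then check the map. For readers without the OpenHarmony test harness, that pattern can be sketched with the standard library alone. TinySafeMap and concurrent_upserts are hypothetical names, and TinySafeMap is a minimal stand-in rather than the SafeMap class itself. The sketch relies on join() alone to wait for the writers, so it needs no fixed sleep_for delay.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Minimal stand-in for a mutex-guarded map (hypothetical, not SafeMap).
class TinySafeMap {
public:
    // Inserts the key or overwrites its value, under the lock.
    void Upsert(const std::string& key, int value)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        map_[key] = value;
    }

    // Copies the value out under the lock; returns false if the key is absent.
    bool Find(const std::string& key, int& value)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = map_.find(key);
        if (it == map_.end()) {
            return false;
        }
        value = it->second;
        return true;
    }

    std::size_t Size()
    {
        std::lock_guard<std::mutex> lock(mutex_);
        return map_.size();
    }

private:
    std::mutex mutex_;
    std::map<std::string, int> map_;
};

// Starts writer_count threads; thread i stores key "A<i>" with value i.
// Returns once every writer has been joined.
void concurrent_upserts(TinySafeMap& data, int writer_count)
{
    std::vector<std::thread> writers;
    writers.reserve(static_cast<std::size_t>(writer_count));
    for (int i = 0; i < writer_count; ++i) {
        writers.emplace_back([&data, i] { data.Upsert("A" + std::to_string(i), i); });
    }
    for (auto& t : writers) {
        t.join(); // after join, the writer's insert is visible to this thread
    }
}
```

After concurrent_upserts(data, 100) returns, the map should hold 100 entries and Find("A" + std::to_string(i), v) should yield v == i for every i. A test like this can only expose a race, not prove its absence, so running it under a tool such as ThreadSanitizer gives more confidence.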
Original link: https://blog.csdn.net/caoshangpa/article/details/78557421