Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safeguard against link overflows in ConcurrentHashMap #2107

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions folly/concurrency/detail/ConcurrentHashMap-detail.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class NodeT : public hazptr_obj_base_linked<
this->set_deleter( // defined in hazptr_obj
concurrenthashmap::HazptrDeleter<Allocator>());
this->set_cohort_tag(cohort); // defined in hazptr_obj
this->acquire_link_safe(); // defined in hazptr_obj_base_linked
CHECK(this->acquire_link_safe()); // defined in hazptr_obj_base_linked
}

ValueHolder<KeyType, ValueType, Allocator, Atom> item_;
Expand Down Expand Up @@ -361,7 +361,11 @@ class alignas(64) BucketTable {
}
}
// Set longest last run in new bucket, incrementing the refcount.
lastrun->acquire_link(); // defined in hazptr_obj_base_linked
while (!lastrun->acquire_link()) {
cohort->cleanup();
std::this_thread::yield();
}

newbuckets->buckets_[lastidx]().store(lastrun, std::memory_order_relaxed);
// Clone remaining nodes
for (; node != lastrun;
Expand Down Expand Up @@ -407,7 +411,12 @@ class alignas(64) BucketTable {
}

template <typename K, typename MatchFunc>
std::size_t erase(size_t h, const K& key, Iterator* iter, MatchFunc match) {
std::size_t erase(
size_t h,
const K& key,
Iterator* iter,
MatchFunc match,
hazptr_obj_cohort<Atom>* cohort) {
Node* node{nullptr};
{
std::lock_guard<Mutex> g(m_);
Expand All @@ -426,7 +435,10 @@ class alignas(64) BucketTable {
}
auto next = node->next_.load(std::memory_order_relaxed);
if (next) {
next->acquire_link(); // defined in hazptr_obj_base_linked
while (!next->acquire_link()) {
cohort->cleanup();
std::this_thread::yield();
} // defined in hazptr_obj_base_linked
}
if (prev) {
prev->next_.store(next, std::memory_order_release);
Expand Down Expand Up @@ -709,7 +721,10 @@ class alignas(64) BucketTable {
auto next = node->next_.load(std::memory_order_relaxed);
cur->next_.store(next, std::memory_order_relaxed);
if (next) {
next->acquire_link(); // defined in hazptr_obj_base_linked
while (!next->acquire_link()) {
cohort->cleanup();
std::this_thread::yield();
} // defined in hazptr_obj_base_linked
}
prev->store(cur, std::memory_order_release);
it.setNode(cur, buckets, bcount, idx);
Expand Down Expand Up @@ -1347,7 +1362,12 @@ class alignas(64) SIMDTable {
}

template <typename K, typename MatchFunc>
std::size_t erase(size_t h, const K& key, Iterator* iter, MatchFunc match) {
std::size_t erase(
size_t h,
const K& key,
Iterator* iter,
MatchFunc match,
hazptr_obj_cohort<Atom>* /* cohort */) {
const HashPair hp = splitHash(h);

std::unique_lock<Mutex> g(m_);
Expand Down Expand Up @@ -1880,7 +1900,7 @@ class alignas(64) ConcurrentHashMapSegment {
template <typename K, typename MatchFunc>
size_type erase_internal(
size_t h, const K& key, Iterator* iter, MatchFunc match) {
return impl_.erase(h, key, iter, match);
return impl_.erase(h, key, iter, match, cohort_);
}

// Unfortunately because we are reusing nodes on rehash, we can't
Expand Down
52 changes: 51 additions & 1 deletion folly/concurrency/test/ConcurrentHashMapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
#include <folly/concurrency/ConcurrentHashMap.h>

#include <atomic>
#include <latch>
#include <limits>
#include <memory>
#include <thread>
#include <vector>

#include <folly/Traits.h>
#include <folly/container/test/TrackingTypes.h>
Expand Down Expand Up @@ -1131,6 +1134,52 @@ TYPED_TEST_P(ConcurrentHashMapTest, ConcurrentInsertClear) {
}
}

TYPED_TEST_P(ConcurrentHashMapTest, StressTestReclamation) {
  // Create a map where we keep reclaiming a lot of objects that are linked to
  // one node, to exercise the link-overflow safeguard.

  // Ensure all entries are mapped to a single segment.
  auto constant_hash = [](unsigned long) -> uint64_t { return 0; };
  CHM<unsigned long, unsigned long, decltype(constant_hash)> map;
  // A key that the test key has a link to - to guard against immediate
  // reclamation.
  static constexpr unsigned long key_prev = 0;
  // A key that keeps being reclaimed repeatedly.
  static constexpr unsigned long key_test = 1;
  // A key that is linked to the test key.
  static constexpr unsigned long key_link_explosion = 2;

  EXPECT_TRUE(map.insert(std::make_pair(key_prev, 0)).second);
  EXPECT_TRUE(map.insert(std::make_pair(key_test, 0)).second);
  EXPECT_TRUE(map.insert(std::make_pair(key_link_explosion, 0)).second);

  std::vector<std::thread> threads;
  // The number of links is stored as a uint16_t, so having 65K threads should
  // cause sufficient racing.
  static constexpr uint64_t num_threads = std::numeric_limits<uint16_t>::max();
  static constexpr uint64_t iters = 100;
  std::latch start{num_threads};
  for (uint64_t t = 0; t < num_threads; t++) {
    // Fix: was `lib::thread` (nonexistent namespace) — use std::thread,
    // matching the <thread> include added above.
    threads.push_back(std::thread([t, &map, &start]() {
      // All threads begin the overwrite storm at the same instant.
      start.arrive_and_wait();
      static constexpr uint64_t progress_report_pct =
          (iters / 20); // Every 5% we log progress
      for (uint64_t i = 0; i < iters; i++) {
        if (t == 0 && (i % progress_report_pct) == 0) {
          // To a casual observer - to know that the test is progressing, even
          // if slowly
          LOG(INFO) << "Progress: " << (i * 100 / iters);
        }

        map.insert_or_assign(key_test, i * num_threads);
      }
    }));
  }
  // Fix: was a bare `join;` statement, which does not compile — join each
  // worker so the test does not exit with live threads.
  for (auto& t : threads) {
    t.join();
  }
}

REGISTER_TYPED_TEST_SUITE_P(
ConcurrentHashMapTest,
MapTest,
Expand Down Expand Up @@ -1174,7 +1223,8 @@ REGISTER_TYPED_TEST_SUITE_P(
HeterogeneousInsert,
InsertOrAssignIterator,
EraseClonedNonCopyable,
ConcurrentInsertClear);
ConcurrentInsertClear,
StressTestReclamation);

using folly::detail::concurrenthashmap::bucket::BucketTable;

Expand Down
15 changes: 9 additions & 6 deletions folly/synchronization/HazptrDomain.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,15 @@ class hazptr_domain {
int count = 0;
for (int s = 0; s < kNumShards; ++s) {
if (tagged[s]) {
ObjList match, nomatch;
list_match_condition(tagged[s], match, nomatch, [&](Obj* o) {
return hs.count(o->raw_ptr()) > 0;
});
ObjList nomatch;
{
ObjList match;
list_match_condition(tagged[s], match, nomatch, [&](Obj* o) {
return hs.count(o->raw_ptr()) > 0;
});
List l(match.head(), match.tail());
tagged_[s].push_unlock(l);
}
count += nomatch.count();
auto obj = nomatch.head();
while (obj) {
Expand All @@ -480,8 +485,6 @@ class hazptr_domain {
cohort->push_safe_obj(obj);
obj = next;
}
List l(match.head(), match.tail());
tagged_[s].push_unlock(l);
}
}
return count;
Expand Down
7 changes: 7 additions & 0 deletions folly/synchronization/HazptrObj.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,13 @@ class hazptr_obj_cohort {
DCHECK(l_.empty());
}

/** Force reclamation of any items that have been retired into this
 *  cohort. Intended for callers that must free up link counts (e.g.
 *  when acquire_link() reports saturation) before retrying.
 *  Sequence matters:
 *    1) check_threshold_push() — presumably flushes this cohort's retired
 *       list toward the domain (subject to its threshold logic) — confirm.
 *    2) domain cleanup() — runs a reclamation pass over the domain.
 *    3) reclaim_safe_list() — reclaims objects already proven safe. */
void cleanup() {
  check_threshold_push();
  default_hazptr_domain<Atom>().cleanup();
  reclaim_safe_list();
}

private:
friend class hazptr_domain<Atom>;
friend class hazptr_obj<Atom>;
Expand Down
30 changes: 19 additions & 11 deletions folly/synchronization/HazptrObjLinked.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ class hazptr_obj_linked : public hazptr_obj<Atom> {
Atom<Count> count_{0};

public:
void acquire_link() noexcept { count_inc(kLink); }
bool acquire_link() noexcept { return count_inc(kLink, kLinkMask); }

void acquire_link_safe() noexcept { count_inc_safe(kLink); }
bool acquire_link_safe() noexcept { return count_inc_safe(kLink, kLinkMask); }

void acquire_ref() noexcept { count_inc(kRef); }
bool acquire_ref() noexcept { return count_inc(kRef, kRefMask); }

void acquire_ref_safe() noexcept { count_inc_safe(kRef); }
bool acquire_ref_safe() noexcept { return count_inc_safe(kRef, kRefMask); }

private:
template <typename, template <typename> class, typename>
Expand All @@ -116,17 +116,25 @@ class hazptr_obj_linked : public hazptr_obj<Atom> {
count_.store(val, std::memory_order_release);
}

void count_inc(Count add) noexcept {
auto oldval = count_.fetch_add(add, std::memory_order_acq_rel);
DCHECK_LT(oldval & kLinkMask, kLinkMask);
DCHECK_LT(oldval & kRefMask, kRefMask);
// Saturating increment of the masked sub-field of count_.
// Returns false (without modifying the count) when the field is already
// at its maximum; returns true once the increment has been applied.
bool count_inc(Count add, Count mask) noexcept {
  Count expected = count();
  do {
    if ((expected & mask) == mask) {
      return false; // field saturated; caller must reclaim and retry
    }
    // On CAS failure, count_cas refreshes `expected` with the value it
    // observed, so the saturation check is re-evaluated before retrying.
  } while (!count_cas(expected, expected + add));
  return true;
}

void count_inc_safe(Count add) noexcept {
// Saturating increment of the masked sub-field of count_, for use when no
// concurrent mutators exist (hence the non-atomic read-modify-write via
// count()/count_set()) — presumably during object construction before the
// node is shared; confirm against callers of acquire_link_safe().
// Returns false without modifying the count when the field is saturated.
bool count_inc_safe(Count add, Count mask) noexcept {
auto oldval = count();
// Refuse the increment if the field is already at its maximum value.
if ((oldval & mask) == mask) {
return false;
}
count_set(oldval + add);
// NOTE(review): these DCHECKs look like leftovers from the pre-saturation
// version; the mask-specific check above already guards `mask`, and they
// inspect the stale pre-increment value. Consider removing.
DCHECK_LT(oldval & kLinkMask, kLinkMask);
DCHECK_LT(oldval & kRefMask, kRefMask);
return true;
}

bool count_cas(Count& oldval, Count newval) noexcept {
Expand Down
Loading