Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lockfree get next bit for bitvec #221

Merged
merged 4 commits into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions include/cista/containers/bitvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

#include <cassert>
#include <cinttypes>
#include <atomic>
#include <iosfwd>
#include <limits>
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
Expand Down Expand Up @@ -121,6 +123,69 @@ struct basic_bitvec {
check_block(blocks_.size() - 1, sanitized_last_block());
}

std::optional<Key> next_set_bit(size_type const i) const {
if (i >= size()) {
return std::nullopt;
}

auto const first_block_idx = i / bits_per_block;
auto const first_block = blocks_[first_block_idx];
if (first_block != 0U) {
auto const first_bit = i % bits_per_block;
auto const n = std::min(size(), bits_per_block);
for (auto bit = first_bit; bit != n; ++bit) {
if ((first_block & (block_t{1U} << bit)) != 0U) {
return Key{first_block_idx * bits_per_block + bit};
}
}
}

if (first_block_idx + 1U == blocks_.size()) {
return std::nullopt;
}

auto const check_block = [&](size_type const block_idx,
block_t const block) -> std::optional<Key> {
if (block != 0U) {
for (auto bit = size_type{0U}; bit != bits_per_block; ++bit) {
if ((block & (block_t{1U} << bit)) != 0U) {
return Key{block_idx * bits_per_block + bit};
}
}
}
return std::nullopt;
};

for (auto block_idx = first_block_idx + 1U; block_idx != blocks_.size() - 1;
++block_idx) {
if (auto const set_bit_idx = check_block(block_idx, blocks_[block_idx]);
set_bit_idx.has_value()) {
return set_bit_idx;
}
}

if (auto const set_bit_idx =
check_block(blocks_.size() - 1, sanitized_last_block());
set_bit_idx.has_value()) {
return set_bit_idx;
}

return std::nullopt;
}

std::optional<Key> get_next(std::atomic_size_t& next) const {
while (true) {
auto expected = next.load();
auto idx = next_set_bit(Key{static_cast<base_t<Key>>(expected)});
if (!idx.has_value()) {
return std::nullopt;
}
if (next.compare_exchange_weak(expected, *idx + 1U)) {
return idx;
}
}
}

size_type size() const noexcept { return size_; }
bool empty() const noexcept { return size() == 0U; }

Expand Down
63 changes: 63 additions & 0 deletions test/bitvec_test.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "doctest.h"

#include <sstream>
#include <thread>
#include <vector>

#ifdef SINGLE_HEADER
Expand Down Expand Up @@ -87,3 +88,65 @@ TEST_CASE("bitvec less than") {
auto const uut_lt = bitvec_lt(uut1, uut2);
CHECK(ref_lt == uut_lt);
}

unsigned long get_random_number() { // period 2^96-1
static std::uint64_t x = 123456789, y = 362436069, z = 521288629;

unsigned long t;
x ^= x << 16;
x ^= x >> 5;
x ^= x << 1;

t = x;
x = y;
y = z;
z = t ^ x ^ y;

return z;
}

TEST_CASE("bitvec parallel") {
constexpr auto const kBits = 1'000'000U;
constexpr auto const kWorkers = 100U;

auto b = cista::raw::bitvec{};
b.resize(kBits);

auto bits = std::vector<std::size_t>{};
bits.resize(b.size() * 0.2);
std::generate(begin(bits), end(bits), [&]() {
auto x = static_cast<std::uint32_t>(get_random_number() % b.size());
b.set(x, true);
return x;
});
std::sort(begin(bits), end(bits));
bits.erase(std::unique(begin(bits), end(bits)), end(bits));

auto next = std::atomic_size_t{0U};
auto workers = std::vector<std::thread>(kWorkers);
auto collected_bits = std::vector<std::vector<std::size_t>>(kWorkers);
for (auto i = 0U; i != kWorkers; ++i) {
workers[i] = std::thread{[&, i]() {
auto next_bit = std::optional<std::size_t>{};
do {
next_bit = b.get_next(next);
if (next_bit.has_value()) {
collected_bits[i].push_back(*next_bit);
}
} while (next_bit.has_value());
}};
}

for (auto& w : workers) {
w.join();
}

auto check = std::vector<std::size_t>{};
for (auto& x : collected_bits) {
check.insert(end(check), begin(x), end(x));
}
std::sort(begin(check), end(check));
check.erase(std::unique(begin(check), end(check)), end(check));

CHECK_EQ(bits, check);
}