Skip to content

Commit

Permalink
Clean up ref implementations with has_payload flag (NVIDIA#368)
Browse files Browse the repository at this point in the history
NVIDIA#356 introduced the `HasPayload` template boolean to distinguish the code
paths of the map and set implementations. As a result, the separate key input of
the base ref insert functions is redundant, because the key can be derived from
the value. This PR cleans up the base ref implementations by removing the key
input, and fixes a logic error in NVIDIA#356 where the flag was inverted: a set
has no payload, while a map does.
  • Loading branch information
PointKernel committed Sep 15, 2023
1 parent dcd5a99 commit 0cd4da0
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 33 deletions.
55 changes: 38 additions & 17 deletions include/cuco/detail/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,23 @@ class open_addressing_ref_impl {
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param key Key of the element to insert
* @param value The element to insert
* @param predicate Predicate used to compare slot content against `key`
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Predicate>
__device__ bool insert(key_type const& key,
value_type const& value,
Predicate const& predicate) noexcept
__device__ bool insert(value_type const& value, Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const key = [&]() {
if constexpr (HasPayload) {
return value.first;
} else {
return value;
}
}();
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());

while (true) {
Expand Down Expand Up @@ -202,18 +207,23 @@ class open_addressing_ref_impl {
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert
* @param key Key of the element to insert
* @param value The element to insert
* @param predicate Predicate used to compare slot content against `key`
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Predicate>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
key_type const& key,
value_type const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
if constexpr (HasPayload) {
return value.first;
} else {
return value;
}
}();
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

while (true) {
Expand Down Expand Up @@ -269,19 +279,25 @@ class open_addressing_ref_impl {
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Predicate Predicate type
*
* @param key Key of the element to insert
* @param value The element to insert
* @param predicate Predicate used to compare slot content against `key`
*
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(key_type const& key,
value_type const& value,
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value,
Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const key = [&]() {
if constexpr (HasPayload) {
return value.first;
} else {
return value;
}
}();
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());

while (true) {
Expand Down Expand Up @@ -326,7 +342,6 @@ class open_addressing_ref_impl {
* @tparam Predicate Predicate type
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param key Key of the element to insert
* @param value The element to insert
* @param predicate Predicate used to compare slot content against `key`
*
Expand All @@ -336,10 +351,16 @@ class open_addressing_ref_impl {
template <bool HasPayload, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
key_type const& key,
value_type const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
if constexpr (HasPayload) {
return value.first;
} else {
return value;
}
}();
auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent());

while (true) {
Expand Down Expand Up @@ -710,11 +731,11 @@ class open_addressing_ref_impl {
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (HasPayload) {
// If it's a set implementation, compare the whole slot content
return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_);
} else {
// If it's a map implementation, compare keys only
return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first);
} else {
// If it's a set implementation, compare the whole slot content
return cuco::detail::bitwise_compare(*old_ptr, this->empty_slot_sentinel_);
}
}();
if (inserted) {
Expand All @@ -723,11 +744,11 @@ class open_addressing_ref_impl {
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
auto const res = [&]() {
if constexpr (HasPayload) {
// If it's a set implementation, compare the whole slot content
return predicate.equal_to(*old_ptr, value);
} else {
// If it's a map implementation, compare keys only
return predicate.equal_to(old_ptr->first, value.first);
} else {
// If it's a set implementation, compare the whole slot content
return predicate.equal_to(*old_ptr, value);
}
}();
return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE
Expand Down
16 changes: 8 additions & 8 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ class operator_impl<
__device__ bool insert(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(value.first, value, ref_.predicate_);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(value, ref_.predicate_);
}

/**
Expand All @@ -225,8 +225,8 @@ class operator_impl<
value_type const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(group, value.first, value, ref_.predicate_);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(group, value, ref_.predicate_);
}
};

Expand Down Expand Up @@ -454,8 +454,8 @@ class operator_impl<
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(value.first, value, ref_.predicate_);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(value, ref_.predicate_);
}

/**
Expand All @@ -475,8 +475,8 @@ class operator_impl<
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(group, value.first, value, ref_.predicate_);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(group, value, ref_.predicate_);
}
};

Expand Down
16 changes: 8 additions & 8 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class operator_impl<op::insert_tag,
__device__ bool insert(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(value, value, ref_.predicate_);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(value, ref_.predicate_);
}

/**
Expand All @@ -117,8 +117,8 @@ class operator_impl<op::insert_tag,
value_type const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(group, value, value, ref_.predicate_);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(group, value, ref_.predicate_);
}
};

Expand Down Expand Up @@ -182,8 +182,8 @@ class operator_impl<op::insert_and_find_tag,
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(value, value, ref_.predicate_);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(value, ref_.predicate_);
}

/**
Expand All @@ -203,8 +203,8 @@ class operator_impl<op::insert_and_find_tag,
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(group, value, value, ref_.predicate_);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(group, value, ref_.predicate_);
}
};

Expand Down

0 comments on commit 0cd4da0

Please sign in to comment.