Skip to content

Commit

Permalink
Add unit tests for DeviceSelect::FlaggedIf
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Mar 20, 2024
1 parent 973d7b7 commit 4c01521
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 128 deletions.
41 changes: 0 additions & 41 deletions cub/cub/device/device_select.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -736,29 +736,6 @@ struct DeviceSelect
stream);
}

//! Overload of ``FlaggedIf`` taking separate ``d_in`` / ``d_out`` ranges plus a trailing
//! ``debug_synchronous`` argument.
//!
//! The body only expands ``CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG`` and then forwards every
//! argument except ``debug_synchronous`` to the ``FlaggedIf`` overload that has no such parameter,
//! returning that overload's ``cudaError_t`` unchanged.
//!
//! NOTE(review): the ``CUB_DETAIL_RUNTIME_DEBUG_SYNC_*`` macros are defined outside this view;
//! judging by their names they mark ``debug_synchronous`` as unsupported -- confirm at the definition.
template <typename InputIteratorT,
typename FlagIterator,
typename OutputIteratorT,
typename NumSelectedIteratorT,
typename SelectOp>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t FlaggedIf(
void* d_temp_storage,
size_t& temp_storage_bytes,
InputIteratorT d_in,
FlagIterator d_flags,
OutputIteratorT d_out,
NumSelectedIteratorT d_num_selected_out,
int num_items,
SelectOp select_op,
cudaStream_t stream,
bool debug_synchronous)
{
CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG

// ``debug_synchronous`` is dropped here; all other arguments are forwarded as-is.
return FlaggedIf<InputIteratorT, FlagIterator, OutputIteratorT, NumSelectedIteratorT, SelectOp>(
d_temp_storage, temp_storage_bytes, d_in, d_flags, d_out, d_num_selected_out, num_items, select_op, stream);
}

template <typename IteratorT, typename FlagIterator, typename NumSelectedIteratorT, typename SelectOp>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t FlaggedIf(
void* d_temp_storage,
Expand Down Expand Up @@ -796,24 +773,6 @@ struct DeviceSelect
stream);
}

//! Overload of ``FlaggedIf`` taking a single ``d_data`` range (no separate output iterator) plus a
//! trailing ``debug_synchronous`` argument.
//!
//! The body only expands ``CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG`` and then forwards every
//! argument except ``debug_synchronous`` to the ``FlaggedIf`` overload that has no such parameter,
//! returning that overload's ``cudaError_t`` unchanged.
//!
//! NOTE(review): presumably this is the in-place variant (``d_data`` read and written) -- the callee
//! is only partially visible here, so verify against it.
template <typename IteratorT, typename FlagIterator, typename NumSelectedIteratorT, typename SelectOp>
CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t FlaggedIf(
void* d_temp_storage,
size_t& temp_storage_bytes,
IteratorT d_data,
FlagIterator d_flags,
NumSelectedIteratorT d_num_selected_out,
int num_items,
SelectOp select_op,
cudaStream_t stream,
bool debug_synchronous)
{
CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG

// ``debug_synchronous`` is dropped here; all other arguments are forwarded as-is.
return FlaggedIf<IteratorT, FlagIterator, NumSelectedIteratorT, SelectOp>(
d_temp_storage, temp_storage_bytes, d_data, d_flags, d_num_selected_out, num_items, select_op, stream);
}

//! @rst
//! Given an input sequence ``d_in`` having runs of consecutive equal-valued keys,
//! only the first key from each run is selectively copied to ``d_out``.
Expand Down
167 changes: 80 additions & 87 deletions cub/test/catch2_test_device_select_flagged_if.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,28 @@
#include <cub/device/device_select.cuh>

#include <thrust/count.h>
#include <thrust/logical.h>

// #include "catch2_test_helper.h"
#include "catch2_test_helper.h"
#include "catch2_test_launch_helper.h"
#include "thrust/functional.h"

template <class T, class FlagT, class Pred>
static c2h::host_vector<T>
get_reference(const c2h::device_vector<T>& in, const c2h::device_vector<FlagT>& flags, Pred const& if_predicate)
get_reference(c2h::device_vector<T> const& in, c2h::device_vector<FlagT> const& flags, Pred if_predicate)
{
struct selector
{
const T* ref_begin = nullptr;
const FlagT* flag_begin = nullptr;
T const* ref_begin = nullptr;
FlagT const* flag_begin = nullptr;
Pred const& if_pred;

constexpr selector(const T* ref, const FlagT* flag, Pred const& pred) noexcept
constexpr selector(T const* ref, FlagT const* flag, Pred const& pred) noexcept
: ref_begin(ref)
, flag_begin(flag)
, if_pred(pred)
{}

bool operator()(const T& val) const
bool operator()(T const& val) const
{
const auto pos = &val - ref_begin;
return static_cast<bool>(if_pred(flag_begin[pos]));
Expand All @@ -70,26 +70,32 @@ DECLARE_LAUNCH_WRAPPER(cub::DeviceSelect::FlaggedIf, select_flagged_if);

// %PARAM% TEST_LAUNCH lid 0:1:2

// Unary predicate usable on host and device: returns true when `elem % 2` is zero.
// NOTE(review): this hunk appears to carry both sides of the change without +/- markers --
// the non-template `is_even` (taking `int const&`) and the templated `is_even_t<T>`
// (taking `T const&`); the function body is shared by both versions.
struct is_even
using custom_t = c2h::custom_type_t<c2h::equal_comparable_t>;

template <typename T>
struct is_even_t
{
__host__ __device__ bool operator()(int const& elem) const
__host__ __device__ bool operator()(T const& elem) const
{
return !(elem % 2);
}
};

// Full specialization of `is_even_t` for `custom_t`: the parity test is applied to the
// element's `key` member (`!(elem.key % 2)`) rather than to the element itself.
// NOTE(review): the `less_than_t` header, its `compare` member and its explicit constructor
// look like removed-side lines of the diff interleaved with the specialization.
template <typename T>
struct less_than_t
template <>
struct is_even_t<custom_t>
{
T compare;

explicit __host__ less_than_t(T compare)
: compare(compare)
{}
__host__ __device__ bool operator()(custom_t elem) const
{
return !(elem.key % 2);
}
};

// Unary predicate usable on host and device: returns true when `a` compares equal to a
// value-initialized `T` (`a == T{}`).
// NOTE(review): `return a < compare;` references a `compare` member this struct does not
// declare; it looks like a removed-side line left over from `less_than_t::operator()`.
struct equal_to_default_t
{
template <typename T>
__host__ __device__ bool operator()(const T& a) const
{
return a < compare;
return a == T{};
}
};

Expand All @@ -112,18 +118,11 @@ struct always_true_t
};

// Type lists handed to CUB_TEST to parameterize the tests below: `all_types` and `types` are
// used for the input element type, `flag_types` for the flag element type.
// NOTE(review): the multi-line lists and the single-line lists appear to be the removed and
// added sides of the same aliases; the added side spells the custom type as `custom_t`.
using all_types =
c2h::type_list<std::uint8_t,
std::uint16_t,
std::uint32_t,
std::uint64_t,
ulonglong2,
ulonglong4,
int,
long2,
c2h::custom_type_t<c2h::equal_comparable_t>>;

using types = c2h::
type_list<std::uint8_t, std::uint32_t, ulonglong4, c2h::custom_type_t<c2h::less_comparable_t, c2h::equal_comparable_t>>;
c2h::type_list<std::uint8_t, std::uint16_t, std::uint32_t, std::uint64_t, ulonglong2, ulonglong4, int, long2, custom_t>;

using types = c2h::type_list<std::uint8_t, std::uint32_t, ulonglong4, custom_t>;

using flag_types = c2h::type_list<std::uint8_t, std::uint64_t, custom_t>;

CUB_TEST("DeviceSelect::FlaggedIf can run with empty input", "[device][select_flagged_if]", types)
{
Expand Down Expand Up @@ -183,109 +182,103 @@ CUB_TEST("DeviceSelect::FlaggedIf handles no matched", "[device][select_flagged_
REQUIRE(num_selected_out[0] == 0);
}

// Runs `select_flagged_if` on generated input and flags, then checks that
//   * the reported count matches a `thrust::count_if` over the flags with the same predicate,
//   * the input range is unchanged (`reference_in == in`),
//   * every output element past the reported count still satisfies `equal_to_default_t`,
//   * the selected prefix equals the sequence produced by `get_reference`.
// NOTE(review): this hunk interleaves the two former tests ("does not change input" and
// "is stable", which used `int` flags generated in [0, 1] and `thrust::count(..., 0)`) with the
// merged replacement; it is not compilable as shown.
CUB_TEST("DeviceSelect::FlaggedIf does not change input", "[device][select_flagged_if]", types)
CUB_TEST("DeviceSelect::FlaggedIf does not change input and is stable",
"[device][select_flagged_if]",
c2h::type_list<std::uint8_t, std::uint64_t>,
flag_types)
{
using type = typename c2h::get<0, TestType>;
using input_type = typename c2h::get<0, TestType>;
using flag_type = typename c2h::get<1, TestType>;

const int num_items = GENERATE_COPY(take(2, random(1, 1000000)));
c2h::device_vector<type> in(num_items);
c2h::device_vector<type> out(num_items);
c2h::device_vector<input_type> in(num_items);
c2h::device_vector<input_type> out(num_items);
c2h::gen(CUB_SEED(2), in);

c2h::device_vector<int> flags(num_items);
c2h::gen(CUB_SEED(1), flags, 0, 1);
const int num_selected = static_cast<int>(thrust::count(c2h::device_policy, flags.begin(), flags.end(), 0));
is_even_t<flag_type> is_even{};

c2h::device_vector<flag_type> flags(num_items);
c2h::gen(CUB_SEED(1), flags);
const int num_selected = static_cast<int>(thrust::count_if(c2h::device_policy, flags.begin(), flags.end(), is_even));
const c2h::host_vector<input_type> reference_out = get_reference(in, flags, is_even);

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
int* d_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());

// copy input first
c2h::device_vector<type> reference = in;
c2h::device_vector<input_type> reference_in = in;

select_flagged_if(in.begin(), flags.begin(), out.begin(), d_num_selected_out, num_items, is_even{});
select_flagged_if(in.begin(), flags.begin(), out.begin(), d_num_selected_out, num_items, is_even);

REQUIRE(num_selected == num_selected_out[0]);
REQUIRE(reference == in);
}
REQUIRE(reference_in == in);

CUB_TEST("DeviceSelect::FlaggedIf is stable",
"[device][select_flagged_if]",
c2h::type_list<c2h::custom_type_t<c2h::equal_comparable_t>>)
{
using type = typename c2h::get<0, TestType>;

const int num_items = GENERATE_COPY(take(2, random(1, 1000000)));
c2h::device_vector<type> in(num_items);
c2h::device_vector<type> out(num_items);
c2h::gen(CUB_SEED(2), in);

c2h::device_vector<int> flags(num_items);
c2h::gen(CUB_SEED(1), flags, 0, 1);
const int num_selected = static_cast<int>(thrust::count(c2h::device_policy, flags.begin(), flags.end(), 0));
const c2h::host_vector<type> reference = get_reference(in, flags, is_even{});

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
int* d_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());

select_flagged_if(in.begin(), flags.begin(), out.begin(), d_num_selected_out, num_items, is_even{});
// Ensure that we did not overwrite other elements
const auto boundary = out.begin() + num_selected_out[0];
REQUIRE(thrust::all_of(c2h::device_policy, boundary, out.end(), equal_to_default_t{}));

// Shrink `out` to the reported count so it can be compared against the reference as a whole.
out.resize(num_selected_out[0]);
REQUIRE(num_selected == num_selected_out[0]);
REQUIRE(reference == out);
REQUIRE(reference_out == out);
}

// Runs `select_flagged_if` with vector iterators for input, flags and output, then checks the
// reported count against `thrust::count_if` and the selected prefix against `get_reference`.
// NOTE(review): the second header line names the test "DeviceSelect::If works with iterators"
// with tag "[select_if]", yet the body calls `select_flagged_if` (declared above as the launch
// wrapper for `cub::DeviceSelect::FlaggedIf`). This looks like a copy-paste slip; the title and
// tag should presumably read "DeviceSelect::FlaggedIf ..." / "[device][select_flagged_if]".
// NOTE(review): removed- and added-side lines are interleaved in this hunk; it is not
// compilable as shown.
CUB_TEST("DeviceSelect::FlaggedIf works with iterators", "[device][select_flagged_if]", all_types)
CUB_TEST("DeviceSelect::If works with iterators", "[device][select_if]", all_types, flag_types)
{
using type = typename c2h::get<0, TestType>;
using input_type = typename c2h::get<0, TestType>;
using flag_type = typename c2h::get<1, TestType>;

const int num_items = GENERATE_COPY(take(2, random(1, 1000000)));
c2h::device_vector<type> in(num_items);
c2h::device_vector<type> out(num_items);
c2h::device_vector<input_type> in(num_items);
c2h::device_vector<input_type> out(num_items);
c2h::gen(CUB_SEED(2), in);

c2h::device_vector<int> flags(num_items);
c2h::gen(CUB_SEED(1), flags, 0, 1);
const int num_selected = static_cast<int>(thrust::count(c2h::device_policy, flags.begin(), flags.end(), 0));
const c2h::host_vector<type> reference = get_reference(in, flags, is_even{});
is_even_t<flag_type> is_even{};

c2h::device_vector<flag_type> flags(num_items);
c2h::gen(CUB_SEED(1), flags);
const int num_selected = static_cast<int>(thrust::count_if(c2h::device_policy, flags.begin(), flags.end(), is_even));
const c2h::host_vector<input_type> reference = get_reference(in, flags, is_even);

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
int* d_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());
int* d_first_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());

select_flagged_if(in.data(), flags.begin(), out.data(), d_num_selected_out, num_items, is_even{});
select_flagged_if(in.begin(), flags.begin(), out.begin(), d_first_num_selected_out, num_items, is_even);

// Shrink `out` to the reported count so it can be compared against the reference as a whole.
out.resize(num_selected_out[0]);
REQUIRE(num_selected == num_selected_out[0]);
REQUIRE(reference == out);
}

CUB_TEST("DeviceSelect::FlaggedIf works with pointers", "[device][select_flagged_if]", types)
CUB_TEST("DeviceSelect::Flagged works with pointers", "[device][select_flagged]", types, flag_types)
{
using type = typename c2h::get<0, TestType>;
using input_type = typename c2h::get<0, TestType>;
using flag_type = typename c2h::get<1, TestType>;

const int num_items = GENERATE_COPY(take(2, random(1, 1000000)));
c2h::device_vector<type> in(num_items);
c2h::device_vector<type> out(num_items);
c2h::device_vector<input_type> in(num_items);
c2h::device_vector<input_type> out(num_items);
c2h::gen(CUB_SEED(2), in);

c2h::device_vector<int> flags(num_items);
c2h::gen(CUB_SEED(1), flags, 0, 1);
is_even_t<flag_type> is_even{};

c2h::device_vector<flag_type> flags(num_items);
c2h::gen(CUB_SEED(1), flags);

const int num_selected = static_cast<int>(thrust::count(c2h::device_policy, flags.begin(), flags.end(), 0));
const c2h::host_vector<type> reference = get_reference(in, flags, is_even{});
const int num_selected = static_cast<int>(thrust::count_if(c2h::device_policy, flags.begin(), flags.end(), is_even));
const c2h::host_vector<input_type> reference = get_reference(in, flags, is_even);

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
int *d_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());

select_flagged_if(thrust::raw_pointer_cast(in.data()),
thrust::raw_pointer_cast(flags.data()),
thrust::raw_pointer_cast(out.data()),
d_num_selected_out,
num_items,
is_even{});
select_flagged_if(
thrust::raw_pointer_cast(in.data()),
thrust::raw_pointer_cast(flags.data()),
thrust::raw_pointer_cast(out.data()),
d_num_selected_out,
num_items,
is_even);

out.resize(num_selected_out[0]);
REQUIRE(num_selected == num_selected_out[0]);
Expand Down

0 comments on commit 4c01521

Please sign in to comment.