Skip to content

Commit

Permalink
Add deviceSelect flagged_if API test example
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Mar 25, 2024
1 parent 71ca09b commit 50a4890
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 39 deletions.
46 changes: 10 additions & 36 deletions cub/cub/device/device_select.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -830,43 +830,17 @@ struct DeviceSelect
//!
//! The code snippet below illustrates the compaction of items selected from an ``int`` device vector.
//!
//! .. code-block:: c++
//!
//! #include <cub/cub.cuh> // or equivalently <cub/device/device_select.cuh>
//!
//! struct is_even_t
//! {
//! __host__ __device__ bool operator()(int const& elem) const
//! {
//! return !(elem % 2);
//! }
//! };
//!
//! // Declare, allocate, and initialize device-accessible pointers for input,
//! // flags, and output
//! int num_items; // e.g., 8
//! int *d_data; // e.g., [0, 1, 2, 3, 4, 5, 6, 7]
//! char *d_flags; // e.g., [8, 6, 7, 5, 3, 0, 9, 3]
//! int *d_num_selected_out; // e.g., [ ]
//! ...
//!
//! // Determine temporary device storage requirements
//! void *d_temp_storage = NULL;
//! size_t temp_storage_bytes = 0;
//! cub::DeviceSelect::FlaggedIf(
//! d_temp_storage, temp_storage_bytes,
//! d_in, d_flags, d_num_selected_out, num_items, is_even);
//!
//! // Allocate temporary storage
//! cudaMalloc(&d_temp_storage, temp_storage_bytes);
//!
//! // Run selection
//! cub::DeviceSelect::Flagged(
//! d_temp_storage, temp_storage_bytes,
//! d_in, d_flags, d_num_selected_out, num_items, is_even);
//! .. literalinclude:: ../../test/catch2_test_device_select_api.cu
//! :language: c++
//! :dedent:
//! :start-after: example-begin segmented-select-iseven
//! :end-before: example-end segmented-select-iseven
//!
//! // d_data <-- [0, 1, 5]
//! // d_num_selected_out <-- [3]
//! .. literalinclude:: ../../test/catch2_test_device_select_api.cu
//! :language: c++
//! :dedent:
//! :start-after: example-begin segmented-select-flaggedif-inplace
//! :end-before: example-end segmented-select-flaggedif-inplace
//!
//! @endrst
//!
Expand Down
30 changes: 30 additions & 0 deletions cub/test/catch2_test_device_select_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,33 @@ CUB_TEST("cub::DeviceSelect::FlaggedIf works with int data elements", "[select][
d_out.resize(d_num_selected_out[0]);
REQUIRE(d_out == expected);
}

CUB_TEST("cub::DeviceSelect::FlaggedIf in-place works with int data elements", "[select][device]")
{
// example-begin segmented-select-flaggedif-inplace
constexpr int num_items = 8;
thrust::device_vector<int> d_data = {0, 1, 2, 3, 4, 5, 6, 7};
thrust::device_vector<int> d_flags = {8, 6, 7, 5, 3, 0, 9, 3};
thrust::device_vector<int> d_num_selected_out(num_items);
is_even_t is_even{};

// Determine temporary device storage requirements
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
cub::DeviceSelect::FlaggedIf(
d_temp_storage, temp_storage_bytes, d_data.begin(), d_flags.begin(), d_num_selected_out.data(), num_items, is_even);

// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);

// Run selection
cub::DeviceSelect::FlaggedIf(
d_temp_storage, temp_storage_bytes, d_data.begin(), d_flags.begin(), d_num_selected_out.data(), num_items, is_even);

thrust::device_vector<int> expected{0, 1, 5};
// example-end segmented-select-flaggedif-inplace

REQUIRE(d_num_selected_out[0] == static_cast<int>(expected.size()));
d_data.resize(d_num_selected_out[0]);
REQUIRE(d_data == expected);
}
6 changes: 3 additions & 3 deletions cub/test/catch2_test_device_select_flagged_if.cu
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ CUB_TEST("DeviceSelect::FlaggedIf does not change input and is stable",
c2h::device_vector<flag_type> flags(num_items);
c2h::gen(CUB_SEED(1), flags);
const c2h::host_vector<input_type> reference_out = get_reference(in, flags, is_even);
const int num_selected = reference_out.size();
const std::size_t num_selected = reference_out.size();

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
Expand Down Expand Up @@ -234,7 +234,7 @@ CUB_TEST("DeviceSelect::FlaggedIf works with iterators", "[device][select_if]",
c2h::device_vector<flag_type> flags(num_items);
c2h::gen(CUB_SEED(1), flags);
const c2h::host_vector<input_type> reference = get_reference(in, flags, is_even);
const int num_selected = reference.size();
const std::size_t num_selected = reference.size();

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
Expand Down Expand Up @@ -263,7 +263,7 @@ CUB_TEST("DeviceSelect::FlaggedIf works with pointers", "[device][select_flagged
c2h::gen(CUB_SEED(1), flags);

const c2h::host_vector<input_type> reference = get_reference(in, flags, is_even);
const int num_selected = reference.size();
const std::size_t num_selected = reference.size();

// Needs to be device accessible
c2h::device_vector<int> num_selected_out(1, 0);
Expand Down

0 comments on commit 50a4890

Please sign in to comment.