Skip to content

Commit

Permalink
Add tests for value based block scan APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed May 23, 2024
1 parent ea1fbe8 commit 37e726e
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions cub/test/catch2_test_block_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ __global__ void block_scan_kernel(T* in, T* out, ActionT action)
}
}

/// Test kernel exercising the one-value-per-thread cub::BlockScan overloads.
///
/// Every thread reads a single element of `in` at the index returned by
/// cub::RowMajorTid for the compile-time block shape, passes it (by reference)
/// together with a BlockScan instance to `action`, and then stores whatever
/// value `action` left behind to the same index of `out`.
///
/// Preconditions visible from the code:
///  - indexing uses the in-block thread id only (no blockIdx term) and there is
///    no bounds check, so `in` and `out` must hold at least one element per
///    thread of the block;
///  - NOTE(review): the launch is presumably expected to use a block of exactly
///    (BlockDimX, BlockDimY, BlockDimZ) threads — confirm at the call sites.
///
/// \param in     device pointer to the per-thread input values
/// \param out    device pointer receiving the per-thread results
/// \param action callable invoked as `action(block_scan, value)`
template <cub::BlockScanAlgorithm Algorithm, int BlockDimX, int BlockDimY, int BlockDimZ, class T, class ActionT>
__global__ void block_scan_value_based_kernel(T* in, T* out, ActionT action)
{
  using block_scan_t = cub::BlockScan<T, BlockDimX, Algorithm, BlockDimY, BlockDimZ>;

  // Temporary storage handed to the BlockScan instance below.
  __shared__ typename block_scan_t::TempStorage temp_storage;

  const int linear_tid = static_cast<int>(cub::RowMajorTid(BlockDimX, BlockDimY, BlockDimZ));

  // One item per thread: load, scan in place via the action, store.
  T item = in[linear_tid];

  block_scan_t block_scan(temp_storage);
  action(block_scan, item);

  out[linear_tid] = item;
}

template <cub::BlockScanAlgorithm Algorithm,
int ItemsPerThread,
int BlockDimX,
Expand Down Expand Up @@ -142,6 +162,42 @@ struct min_op_t
}
};

/// Scan action for the single-value BlockScan overloads: replaces the
/// thread's value with the result of a cub::Min scan, without an initial value.
///
/// `Mode` selects between `ExclusiveScan` (when it equals
/// scan_mode::exclusive) and `InclusiveScan` (any other value).
///
/// NOTE(review): CUB documents thread 0's output of an exclusive scan without
/// an initial value as undefined — confirm callers account for that.
template <scan_mode Mode>
struct min_op_value_t
{
  /// \param scan        block scan object the collective is invoked on
  /// \param thread_data in: this thread's input; out: this thread's scan result
  template <class BlockScanT>
  __device__ void operator()(BlockScanT& scan, int& thread_data) const
  {
    // Snapshot the input; the same variable is also the output slot.
    const int input = thread_data;

    if (Mode != scan_mode::exclusive)
    {
      scan.InclusiveScan(input, thread_data, cub::Min{});
    }
    else
    {
      scan.ExclusiveScan(input, thread_data, cub::Min{});
    }
  }
};

/// Scan action for the single-value BlockScan overloads: replaces the
/// thread's value with the result of a cub::Min scan that is passed
/// `initial_value` as an extra argument.
///
/// `Mode` selects between `ExclusiveScan` (when it equals
/// scan_mode::exclusive) and `InclusiveScan` (any other value).
template <class T, scan_mode Mode>
struct min_op_value_init_t
{
  /// Value forwarded to the scan call as its initial value.
  T initial_value;

  /// \param scan        block scan object the collective is invoked on
  /// \param thread_data in: this thread's input; out: this thread's scan result
  template <class BlockScanT>
  __device__ void operator()(BlockScanT& scan, int& thread_data) const
  {
    // Snapshot the input; the same variable is also the output slot.
    const int input = thread_data;

    if (Mode != scan_mode::exclusive)
    {
      scan.InclusiveScan(input, thread_data, initial_value, cub::Min{});
    }
    else
    {
      scan.ExclusiveScan(input, thread_data, initial_value, cub::Min{});
    }
  }
};

template <class T, scan_mode Mode>
struct min_init_value_aggregate_op_t
{
Expand Down Expand Up @@ -513,6 +569,86 @@ CUB_TEST("Block scan supports custom scan op", "[scan][block]", algorithm, modes
REQUIRE(h_out == d_out);
}

// Checks the one-value-per-thread BlockScan overloads (no initial value)
// against a host-side reference min-scan, for every combination of scan
// algorithm, scan mode and block shape in the test's type lists.
CUB_TEST("Block scan value based overload works", "[scan][block]", algorithm, modes, block_dim_yz)
{
  constexpr int items_per_thread               = 1;
  constexpr int block_dim_x                    = 64;
  constexpr int block_dim_y                    = c2h::get<2, TestType>::value;
  constexpr int block_dim_z                    = block_dim_y;
  constexpr int threads_in_block               = block_dim_x * block_dim_y * block_dim_z;
  constexpr int tile_size                      = items_per_thread * threads_in_block;
  constexpr cub::BlockScanAlgorithm algorithm  = c2h::get<0, TestType>::value;
  constexpr scan_mode mode                     = c2h::get<1, TestType>::value;

  using type     = int;
  using action_t = min_op_value_t<mode>;

  c2h::device_vector<type> d_out(tile_size);
  c2h::device_vector<type> d_in(tile_size);
  c2h::gen(CUB_SEED(10), d_in);

  // Pin the first input to INT_MIN: every prefix minimum that includes element
  // 0 is then INT_MIN, and the host reference below is seeded with the same
  // value.
  // NOTE(review): this presumably works around thread 0's output being
  // undefined for an exclusive scan without an initial value — confirm.
  d_in[0] = INT_MIN;

  dim3 block_dims(block_dim_x, block_dim_y, block_dim_z);

  block_scan_value_based_kernel<algorithm, block_dim_x, block_dim_y, block_dim_z, type, action_t>
    <<<1, block_dims>>>(thrust::raw_pointer_cast(d_in.data()), thrust::raw_pointer_cast(d_out.data()), action_t{});

  // Kernel launches do not report errors themselves: check the launch
  // configuration first, then wait for the kernel so execution faults are
  // reported here rather than at some later, unrelated call.
  REQUIRE(cudaSuccess == cudaPeekAtLastError());
  REQUIRE(cudaSuccess == cudaDeviceSynchronize());

  // Host reference: same input, same operator, seeded with INT_MIN.
  c2h::host_vector<type> h_out = d_in;
  host_scan(
    mode,
    h_out,
    [](type l, type r) {
      return std::min(l, r);
    },
    INT_MIN);

  REQUIRE(h_out == d_out);
}

// Checks the one-value-per-thread BlockScan overloads that take an initial
// value against a host-side reference min-scan seeded with the same value, for
// every combination of scan algorithm, scan mode and block shape in the test's
// type lists.
CUB_TEST("Block scan value based overload works with initial value", "[scan][block]", algorithm, modes, block_dim_yz)
{
  constexpr int items_per_thread               = 1;
  constexpr int block_dim_x                    = 64;
  constexpr int block_dim_y                    = c2h::get<2, TestType>::value;
  constexpr int block_dim_z                    = block_dim_y;
  constexpr int threads_in_block               = block_dim_x * block_dim_y * block_dim_z;
  constexpr int tile_size                      = items_per_thread * threads_in_block;
  constexpr cub::BlockScanAlgorithm algorithm  = c2h::get<0, TestType>::value;
  constexpr scan_mode mode                     = c2h::get<1, TestType>::value;

  using type     = int;
  using action_t = min_op_value_init_t<type, mode>;

  c2h::device_vector<type> d_out(tile_size);
  c2h::device_vector<type> d_in(tile_size);
  c2h::gen(CUB_SEED(10), d_in);
  // The input is deliberately NOT pinned to INT_MIN here. Under a min-scan an
  // INT_MIN first element forces every prefix containing it to INT_MIN, which
  // would make the initial value unobservable in inclusive mode and visible
  // only at index 0 in exclusive mode. The initial value seeds the scan on both
  // the device and the host side, so no such pin is needed.

  const type initial_value = static_cast<type>(GENERATE_COPY(take(2, random(0, tile_size))));

  dim3 block_dims(block_dim_x, block_dim_y, block_dim_z);

  block_scan_value_based_kernel<algorithm, block_dim_x, block_dim_y, block_dim_z, type, action_t>
    <<<1, block_dims>>>(
      thrust::raw_pointer_cast(d_in.data()), thrust::raw_pointer_cast(d_out.data()), action_t{initial_value});

  // Kernel launches do not report errors themselves: check the launch
  // configuration first, then wait for the kernel so execution faults are
  // reported here rather than at some later, unrelated call.
  REQUIRE(cudaSuccess == cudaPeekAtLastError());
  REQUIRE(cudaSuccess == cudaDeviceSynchronize());

  // Host reference: same input, same operator, same initial value.
  c2h::host_vector<type> h_out = d_in;
  host_scan(
    mode,
    h_out,
    [](type l, type r) {
      return std::min(l, r);
    },
    initial_value);

  REQUIRE(h_out == d_out);
}

CUB_TEST("Block custom op scan works with initial value", "[scan][block]", algorithm, modes, block_dim_yz)
{
constexpr int items_per_thread = 3;
Expand Down

0 comments on commit 37e726e

Please sign in to comment.