Skip to content

Commit

Permalink
Merge pull request #380 from elstehle/enh/catch2-device-scan-by-key
Browse files Browse the repository at this point in the history
Ports cub::DeviceScanByKey tests to Catch2
  • Loading branch information
elstehle committed Sep 3, 2023
2 parents c091d56 + 4eabd27 commit eea703e
Show file tree
Hide file tree
Showing 5 changed files with 864 additions and 1,107 deletions.
17 changes: 17 additions & 0 deletions cub/test/c2h/generators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,26 @@ void init_key_segments(const thrust::device_vector<OffsetT> &segment_offsets,
// Explicit instantiations of init_key_segments for std::uint32_t segment offsets,
// one per supported key (output) type.
template void init_key_segments(const thrust::device_vector<std::uint32_t> &segment_offsets,
std::int32_t *out,
std::size_t element_size);
template void init_key_segments(const thrust::device_vector<std::uint32_t> &segment_offsets,
std::uint8_t *out,
std::size_t element_size);
template void init_key_segments(const thrust::device_vector<std::uint32_t> &segment_offsets,
float *out,
std::size_t element_size);
template void init_key_segments(const thrust::device_vector<std::uint32_t> &segment_offsets,
custom_type_state_t *out,
std::size_t element_size);
// The half_t instantiation is only emitted when TEST_HALF_T is defined.
#ifdef TEST_HALF_T
template void init_key_segments(const thrust::device_vector<std::uint32_t> &segment_offsets,
half_t *out,
std::size_t element_size);
#endif

// The bfloat16_t instantiation is only emitted when TEST_BF_T is defined.
#ifdef TEST_BF_T
template void init_key_segments(const thrust::device_vector<std::uint32_t> &segment_offsets,
bfloat16_t *out,
std::size_t element_size);
#endif
} // namespace detail

template <typename T>
Expand Down
173 changes: 165 additions & 8 deletions cub/test/catch2_test_device_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,189 @@

#pragma once

template <typename InputIt, typename OutputIt, typename T, typename BinaryOp>
#include <cub/detail/type_traits.cuh>
#include <cub/thread/thread_operators.cuh>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

/**
 * @brief Compile-time bundle of the four types that parameterize a *-by-key test case.
 *
 * @tparam InputT      Type of the input values
 * @tparam OutputT     Type of the output values (defaults to the input type)
 * @tparam KeyT        Type of the keys (defaults to std::int32_t)
 * @tparam EqualityOpT Type of the functor used to compare keys (defaults to cub::Equality)
 */
template <class InputT,
          class OutputT     = InputT,
          class KeyT        = std::int32_t,
          class EqualityOpT = cub::Equality>
struct type_quad
{
  // Key-related aliases
  using key_t   = KeyT;
  using eq_op_t = EqualityOpT;

  // Value-related aliases
  using input_t  = InputT;
  using output_t = OutputT;
};

/**
 * @brief Mod2Equality (used for integral keys, making keys more likely to equal each other)
 *
 * Two keys compare equal when `a % 2` and `b % 2` yield the same value.
 */
struct Mod2Equality
{
  /**
   * @param a Left-hand key
   * @param b Right-hand key
   * @return `true` if both keys have the same remainder modulo 2, `false` otherwise
   */
  template <typename T>
  __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
  {
    // The result is a predicate, so it is returned as bool rather than being converted back to
    // the key type. Note that for signed keys `%` keeps the sign of the dividend, so -1 and 1
    // yield different remainders and therefore compare unequal.
    return (a % 2) == (b % 2);
  }
};

/**
 * @brief Sequential reference implementation of an exclusive scan.
 *
 * For every input item, the value accumulated *before* that item is written to the output; the
 * first output item is therefore `init`.
 *
 * @param first  Iterator to the first input item
 * @param last   Iterator one past the last input item
 * @param result Iterator to the first output item
 * @param init   Initial value seeding the accumulator
 * @param op     Binary scan operator, invoked as `op(accumulator, item)`
 */
template <typename InputIt, typename OutputIt, typename InitT, typename BinaryOp>
void compute_exclusive_scan_reference(InputIt first,
                                      InputIt last,
                                      OutputIt result,
                                      InitT init,
                                      BinaryOp op)
{
  using value_t  = cub::detail::value_t<InputIt>;
  using accum_t  = cub::detail::accumulator_t<BinaryOp, InitT, value_t>;
  using output_t = cub::detail::value_t<OutputIt>;

  // Accumulate in the operator's accumulator type and only convert when writing the output
  accum_t acc = static_cast<accum_t>(init);
  for (; first != last; ++first)
  {
    // Exclusive: emit the running value before folding in the current item
    *result++ = static_cast<output_t>(acc);
    acc       = op(acc, *first);
  }
}

template <typename InputIt, typename OutputIt, typename T, typename BinaryOp>
template <typename InputIt, typename OutputIt, typename BinaryOp, typename InitT>
void compute_inclusive_scan_reference(InputIt first,
InputIt last,
OutputIt result,
BinaryOp op,
T init)
InitT init)
{
T acc = init;
using value_t = cub::detail::value_t<InputIt>;
using accum_t = cub::detail::accumulator_t<BinaryOp, InitT, value_t>;
using output_t = cub::detail::value_t<OutputIt>;
accum_t acc = static_cast<accum_t>(init);
for (; first != last; ++first)
{
acc = op(acc, *first);
*result++ = acc;
*result++ = static_cast<output_t>(acc);
}
}

/**
 * @brief Sequential host-side reference implementation of an exclusive scan-by-key.
 *
 * Consecutive items whose keys compare equal (according to `equality_op`, applied to adjacent
 * keys) form a segment. Within each segment, the first output item is `init` and every following
 * output item is the value accumulated over the preceding items of that segment.
 *
 * @param h_values_it   Random-access iterator to the input values
 * @param h_keys_it     Random-access iterator to the keys
 * @param result_out_it Random-access iterator to the output values
 * @param scan_op       Binary scan operator, invoked as `scan_op(accumulator, item)`
 * @param equality_op   Binary predicate deciding whether two adjacent keys belong to one segment
 * @param init          Initial value each segment's scan is seeded with
 * @param num_items     Number of items (values and keys) to process
 */
template <typename ValueInItT,
          typename KeyInItT,
          typename ValuesOutItT,
          typename ScanOpT,
          typename EqualityOpT,
          typename InitT>
void compute_exclusive_scan_by_key_reference(ValueInItT h_values_it,
                                             KeyInItT h_keys_it,
                                             ValuesOutItT result_out_it,
                                             ScanOpT scan_op,
                                             EqualityOpT equality_op,
                                             InitT init,
                                             std::size_t num_items)
{
  using value_t  = cub::detail::value_t<ValueInItT>;
  using accum_t  = cub::detail::accumulator_t<ScanOpT, InitT, value_t>;
  using output_t = cub::detail::value_t<ValuesOutItT>;

  // The loop condition already covers num_items == 0; the index is advanced inside the body.
  for (std::size_t i = 0; i < num_items;)
  {
    // Head of a segment: the exclusive result is the initial value itself
    accum_t val       = static_cast<accum_t>(h_values_it[i]);
    result_out_it[i]  = static_cast<output_t>(init);
    accum_t inclusive = static_cast<accum_t>(scan_op(init, val));

    ++i;

    // Remaining items of the segment: continue while adjacent keys compare equal
    for (; i < num_items && equality_op(h_keys_it[i - 1], h_keys_it[i]); ++i)
    {
      val              = static_cast<accum_t>(h_values_it[i]);
      result_out_it[i] = static_cast<output_t>(inclusive);
      inclusive        = static_cast<accum_t>(scan_op(inclusive, val));
    }
  }
}

/**
 * @brief Convenience overload of the exclusive scan-by-key reference that takes its values and
 * keys as device vectors, copies them into host vectors, and forwards to the iterator-based
 * overload.
 *
 * @param d_values      Device vector holding the input values
 * @param d_keys        Device vector holding the keys
 * @param result_out_it Iterator to the output values
 * @param scan_op       Binary scan operator
 * @param equality_op   Binary predicate used to compare adjacent keys
 * @param init          Initial value each segment's scan is seeded with
 */
template <typename ValueT,
          typename KeyT,
          typename ValuesOutItT,
          typename ScanOpT,
          typename EqualityOpT,
          typename InitT>
void compute_exclusive_scan_by_key_reference(const thrust::device_vector<ValueT> &d_values,
                                             const thrust::device_vector<KeyT> &d_keys,
                                             ValuesOutItT result_out_it,
                                             ScanOpT scan_op,
                                             EqualityOpT equality_op,
                                             InitT init)
{
  // Host-side copies of the inputs
  thrust::host_vector<ValueT> h_values(d_values);
  thrust::host_vector<KeyT> h_keys(d_keys);

  // The number of items is dictated by the values
  compute_exclusive_scan_by_key_reference(h_values.cbegin(),
                                          h_keys.cbegin(),
                                          result_out_it,
                                          scan_op,
                                          equality_op,
                                          init,
                                          static_cast<std::size_t>(h_values.size()));
}

template <typename ValueInItT,
typename KeyInItT,
typename ValuesOutItT,
typename ScanOpT,
typename EqualityOpT>
void compute_inclusive_scan_by_key_reference(ValueInItT h_values_it,
KeyInItT h_keys_it,
ValuesOutItT result_out_it,
ScanOpT scan_op,
EqualityOpT equality_op,
std::size_t num_items)
{
using value_t = cub::detail::value_t<ValueInItT>;
using accum_t = cub::detail::accumulator_t<ScanOpT, value_t, value_t>;
using output_t = cub::detail::value_t<ValuesOutItT>;

for (std::size_t i = 0; i < num_items;)
{
accum_t inclusive = h_values_it[i];
result_out_it[i] = static_cast<output_t>(inclusive);

++i;

for (; i < num_items && equality_op(h_keys_it[i - 1], h_keys_it[i]); ++i)
{
accum_t val = h_values_it[i];
inclusive = static_cast<accum_t>(scan_op(inclusive, val));
result_out_it[i] = static_cast<output_t>(inclusive);
}
}
}

/**
 * @brief Convenience overload of the inclusive scan-by-key reference that takes its values and
 * keys as device vectors, copies them into host vectors, and forwards to the iterator-based
 * overload.
 *
 * @param d_values      Device vector holding the input values
 * @param d_keys        Device vector holding the keys
 * @param result_out_it Iterator to the output values
 * @param scan_op       Binary scan operator
 * @param equality_op   Binary predicate used to compare adjacent keys
 */
template <typename ValueT,
          typename KeyT,
          typename ValuesOutItT,
          typename ScanOpT,
          typename EqualityOpT>
void compute_inclusive_scan_by_key_reference(const thrust::device_vector<ValueT> &d_values,
                                             const thrust::device_vector<KeyT> &d_keys,
                                             ValuesOutItT result_out_it,
                                             ScanOpT scan_op,
                                             EqualityOpT equality_op)
{
  // Host-side copies of the inputs
  thrust::host_vector<ValueT> h_values(d_values);
  thrust::host_vector<KeyT> h_keys(d_keys);

  // The number of items is dictated by the values
  compute_inclusive_scan_by_key_reference(h_values.cbegin(),
                                          h_keys.cbegin(),
                                          result_out_it,
                                          scan_op,
                                          equality_op,
                                          static_cast<std::size_t>(h_values.size()));
}
Loading

0 comments on commit eea703e

Please sign in to comment.