diff --git a/cub/test/test_thread_sort.cu b/cub/test/catch2_test_thread_sort.cu similarity index 73% rename from cub/test/test_thread_sort.cu rename to cub/test/catch2_test_thread_sort.cu index e9f6b800127..e1b521dff73 100644 --- a/cub/test/test_thread_sort.cu +++ b/cub/test/catch2_test_thread_sort.cu @@ -32,8 +32,8 @@ #include #include +#include "catch2_test_helper.h" #include "cub/thread/thread_sort.cuh" -#include "test_util.h" struct CustomLess { @@ -71,30 +71,34 @@ __global__ void kernel(const KeyT* keys_in, KeyT* keys_out, const ValueT* values } } -template -void Test() +using value_types = c2h::type_list; +using items_per_thread_list = c2h::enum_type_list; + +CUB_TEST("Test", "[thread_sort]", value_types, items_per_thread_list) { + using key_t = std::uint32_t; + using value_t = c2h::get<0, TestType>; + constexpr int items_per_thread = c2h::get<1, TestType>::value; constexpr unsigned int threads_in_block = 1024; - constexpr unsigned int elements = threads_in_block * ItemsPerThread; + constexpr unsigned int elements = threads_in_block * items_per_thread; thrust::default_random_engine re; - thrust::device_vector data_source(elements); + c2h::device_vector data_source(elements); for (int iteration = 0; iteration < 10; iteration++) { - thrust::sequence(data_source.begin(), data_source.end()); - thrust::shuffle(data_source.begin(), data_source.end(), re); - thrust::device_vector in_keys(data_source); - thrust::device_vector out_keys(elements); + c2h::gen(CUB_SEED(2), data_source); + c2h::device_vector in_keys(data_source); + c2h::device_vector out_keys(elements); thrust::shuffle(data_source.begin(), data_source.end(), re); - thrust::device_vector in_values(data_source); - thrust::device_vector out_values(elements); + c2h::device_vector in_values(data_source); + c2h::device_vector out_values(elements); - thrust::host_vector host_keys(in_keys); - thrust::host_vector host_values(in_values); + c2h::host_vector host_keys(in_keys); + c2h::host_vector host_values(in_values); - kernel<<<1, threads_in_block>>>( + kernel<<<1, threads_in_block>>>( thrust::raw_pointer_cast(in_keys.data()), thrust::raw_pointer_cast(out_keys.data()), thrust::raw_pointer_cast(in_values.data()), @@ -102,8 +106,8 @@ void Test() for (unsigned int tid = 0; tid < threads_in_block; tid++) { - const auto thread_begin = tid * ItemsPerThread; - const auto thread_end = thread_begin + ItemsPerThread; + const auto thread_begin = tid * items_per_thread; + const auto thread_end = thread_begin + items_per_thread; thrust::sort_by_key(host_keys.begin() + thread_begin, host_keys.begin() + thread_end, @@ -111,28 +115,7 @@ void Test() CustomLess{}); } - AssertEquals(host_keys, out_keys); - AssertEquals(host_values, out_values); + CHECK(host_keys == out_keys); + CHECK(host_values == out_values); } } - -template -void Test() -{ - Test(); - Test(); - Test(); - Test(); - Test(); - Test(); - Test(); - Test(); -} - -int main() -{ - Test(); - Test(); - - return 0; -}