diff --git a/cub/cub/device/dispatch/dispatch_unique_by_key.cuh b/cub/cub/device/dispatch/dispatch_unique_by_key.cuh index 3e11394c29..4c051b68e9 100644 --- a/cub/cub/device/dispatch/dispatch_unique_by_key.cuh +++ b/cub/cub/device/dispatch/dispatch_unique_by_key.cuh @@ -45,8 +45,8 @@ #include #include #include -#include #include +#include #include #include @@ -112,6 +112,9 @@ CUB_NAMESPACE_BEGIN * * @param[in] num_tiles * Total number of tiles for the entire problem + * + * @param[in] vsmem + * Memory to support virtual shared memory */ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT::BLOCK_THREADS)) - CUB_DETAIL_KERNEL_ATTRIBUTES - void DeviceUniqueByKeySweepKernel(KeyInputIteratorT d_keys_in, - ValueInputIteratorT d_values_in, - KeyOutputIteratorT d_keys_out, - ValueOutputIteratorT d_values_out, - NumSelectedIteratorT d_num_selected_out, - ScanTileStateT tile_state, - EqualityOpT equality_op, - OffsetT num_items, - int num_tiles) +__launch_bounds__(int( + cub::detail::vsmem_helper_default_fallback_policy_t< + typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT, + AgentUniqueByKey, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + EqualityOpT, + OffsetT>::agent_policy_t::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceUniqueByKeySweepKernel( + KeyInputIteratorT d_keys_in, + ValueInputIteratorT d_values_in, + KeyOutputIteratorT d_keys_out, + ValueOutputIteratorT d_values_out, + NumSelectedIteratorT d_num_selected_out, + ScanTileStateT tile_state, + EqualityOpT equality_op, + OffsetT num_items, + int num_tiles, + cub::detail::vsmem_t vsmem) { - using AgentUniqueByKeyPolicyT = typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT; - - // Thread block type for selecting data from input tiles - using AgentUniqueByKeyT = AgentUniqueByKey; - - // Shared memory for AgentUniqueByKey - __shared__ typename AgentUniqueByKeyT::TempStorage temp_storage; - - // Process tiles - 
AgentUniqueByKeyT(temp_storage, d_keys_in, d_values_in, d_keys_out, d_values_out, equality_op, num_items).ConsumeRange( - num_tiles, - tile_state, - d_num_selected_out); + using VsmemHelperT = cub::detail::vsmem_helper_default_fallback_policy_t< + typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT, + AgentUniqueByKey, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + EqualityOpT, + OffsetT>; + + using AgentUniqueByKeyPolicyT = typename VsmemHelperT::agent_policy_t; + + // Thread block type for selecting data from input tiles (the agent picked by the vsmem helper) + using AgentUniqueByKeyT = typename VsmemHelperT::agent_t; + + // Static shared memory allocation (cub::NullType when the agent's temporary storage is backed by vsmem instead) + __shared__ typename VsmemHelperT::static_temp_storage_t static_temp_storage; + + // Get temporary storage: the static allocation above or this block's slice of vsmem, indexed by the linearized 2D block id + typename AgentUniqueByKeyT::TempStorage& temp_storage = + VsmemHelperT::get_temp_storage(static_temp_storage, vsmem, (blockIdx.x * gridDim.y) + blockIdx.y); + + // Process tiles + AgentUniqueByKeyT(temp_storage, d_keys_in, d_values_in, d_keys_out, d_values_out, equality_op, num_items) + .ConsumeRange(num_tiles, tile_state, d_num_selected_out); + + // If applicable, hints to discard modified cache lines for vsmem + VsmemHelperT::discard_temp_storage(temp_storage); } - /****************************************************************************** * Dispatch ******************************************************************************/ @@ -338,13 +359,16 @@ struct DispatchUniqueByKey : SelectedPolicy cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel) { using Policy = typename ActivePolicyT::UniqueByKeyPolicyT; - using UniqueByKeyAgentT = AgentUniqueByKey; + + using VsmemHelperT = cub::detail::vsmem_helper_default_fallback_policy_t< + Policy, + AgentUniqueByKey, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + EqualityOpT, + OffsetT>; cudaError error = cudaSuccess; do @@ -358,23 +382,14 @@ struct DispatchUniqueByKey : SelectedPolicy } 
// Number of input tiles - int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; - int num_tiles = static_cast(cub::DivideAndRoundUp(num_items, tile_size)); - - // Size of virtual shared memory - int max_shmem = 0; - - error = CubDebug(cudaDeviceGetAttribute(&max_shmem, - cudaDevAttrMaxSharedMemoryPerBlock, - device_ordinal)); - if (cudaSuccess != error) - { - break; - } - std::size_t vshmem_size = detail::VshmemSize(max_shmem, sizeof(typename UniqueByKeyAgentT::TempStorage), num_tiles); + constexpr auto block_threads = VsmemHelperT::agent_policy_t::BLOCK_THREADS; + constexpr auto items_per_thread = VsmemHelperT::agent_policy_t::ITEMS_PER_THREAD; + int tile_size = block_threads * items_per_thread; + int num_tiles = static_cast(cub::DivideAndRoundUp(num_items, tile_size)); + const auto vsmem_size = num_tiles * VsmemHelperT::vsmem_per_block; // Specify temporary storage allocation requirements - size_t allocation_sizes[2] = {0, vshmem_size}; + size_t allocation_sizes[2] = {0, vsmem_size}; // Bytes needed for tile status descriptors error = CubDebug(ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0])); @@ -459,7 +474,7 @@ struct DispatchUniqueByKey : SelectedPolicy int scan_sm_occupancy; error = CubDebug(MaxSmOccupancy(scan_sm_occupancy, // out scan_kernel, - Policy::BLOCK_THREADS)); + block_threads)); if (cudaSuccess != error) { break; @@ -470,26 +485,27 @@ struct DispatchUniqueByKey : SelectedPolicy scan_grid_size.x, scan_grid_size.y, scan_grid_size.z, - Policy::BLOCK_THREADS, + block_threads, (long long)stream, - Policy::ITEMS_PER_THREAD, + items_per_thread, scan_sm_occupancy); } #endif // Invoke select_if_kernel - error = THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - scan_grid_size, Policy::BLOCK_THREADS, 0, stream - ).doit(scan_kernel, - d_keys_in, - d_values_in, - d_keys_out, - d_values_out, - d_num_selected_out, - tile_state, - equality_op, - num_items, - num_tiles); + error = + 
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(scan_grid_size, block_threads, 0, stream) + .doit(scan_kernel, + d_keys_in, + d_values_in, + d_keys_out, + d_values_out, + d_num_selected_out, + tile_state, + equality_op, + num_items, + num_tiles, + cub::detail::vsmem_t{allocations[1]}); // Check for failure to launch error = CubDebug(error); diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index c933338de7..a981d22c15 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -136,6 +136,17 @@ public: return static_temp_storage; } + /** + * @brief Used from within the device algorithm's kernel to get the temporary storage that can be + * passed to the agent, specialized for the case when we can use native shared memory as temporary + * storage. The linear block id is accepted for signature parity with the vsmem overload and is ignored. + */ + static __device__ __forceinline__ typename AgentT::TempStorage& + get_temp_storage(typename AgentT::TempStorage& static_temp_storage, vsmem_t&, std::size_t) + { + return static_temp_storage; + } + /** * @brief Used from within the device algorithm's kernel to get the temporary storage that can be * passed to the agent, specialized for the case when we have to use global memory-backed @@ -148,6 +159,18 @@ public: static_cast(vsmem.gmem_ptr) + (vsmem_per_block * blockIdx.x)); } + /** + * @brief Used from within the device algorithm's kernel to get the temporary storage that can be + * passed to the agent, specialized for the case when we have to use global memory-backed + * virtual shared memory as temporary storage. The linear block id selects this block's slice of vsmem. + */ + static __device__ __forceinline__ typename AgentT::TempStorage& + get_temp_storage(cub::NullType& static_temp_storage, vsmem_t& vsmem, std::size_t linear_block_id) + { + return *reinterpret_cast( + static_cast(vsmem.gmem_ptr) + (vsmem_per_block * linear_block_id)); + } + /** * @brief Hints to discard modified cache lines of the used virtual shared memory. * modified cache lines. 
diff --git a/cub/test/c2h/custom_type.cuh b/cub/test/c2h/custom_type.cuh index 2a1b0c44fd..1b16e70ad2 100644 --- a/cub/test/c2h/custom_type.cuh +++ b/cub/test/c2h/custom_type.cuh @@ -56,6 +56,17 @@ public: }; +template +struct huge_data +{ + template + class type + { + static constexpr auto extra_member_bytes = (TotalSize - sizeof(custom_type_state_t)); + std::uint8_t data[extra_member_bytes]; + }; +}; + template class less_comparable_t { diff --git a/cub/test/catch2_test_device_select_unique_by_key.cu b/cub/test/catch2_test_device_select_unique_by_key.cu index 313b7c9709..d829b9dc12 100644 --- a/cub/test/catch2_test_device_select_unique_by_key.cu +++ b/cub/test/catch2_test_device_select_unique_by_key.cu @@ -27,16 +27,18 @@ #include +#include #include #include #include #include +#include #include #include -#include "catch2_test_launch_helper.h" #include "catch2_test_helper.h" +#include "catch2_test_launch_helper.h" template inline T to_bound(const unsigned long long bound) { @@ -66,6 +68,19 @@ inline c2h::custom_type_t to_bound(const unsigned long return val; } +template +struct index_to_huge_type_op_t +{ + template + __device__ __host__ HugeDataTypeT operator()(const ValueType& val) + { + HugeDataTypeT return_val{}; + return_val.key = val; + return_val.val = val; + return return_val; + } +}; + DECLARE_LAUNCH_WRAPPER(cub::DeviceSelect::UniqueByKey, select_unique_by_key); // %PARAM% TEST_LAUNCH lid 0:1:2 @@ -86,6 +101,9 @@ using all_types = c2h::type_list>; +using huge_types = c2h::type_list::type>, + c2h::custom_type_t::type>>; + using types = c2h::type_list; @@ -338,3 +356,49 @@ CUB_TEST("DeviceSelect::UniqueByKey works with a different output type", "[devic REQUIRE(reference_keys == keys_out); REQUIRE(reference_vals == vals_out); } + +CUB_TEST("DeviceSelect::UniqueByKey works and uses vsmem for large types", + "[device][select_unique_by_key][vsmem]", + huge_types) +{ + using type = std::uint32_t; + using val_type = typename c2h::get<0, TestType>; + + const int 
num_items = GENERATE_COPY(take(2, random(1, 100000))); + thrust::device_vector keys_in(num_items); + thrust::device_vector keys_out(num_items); + thrust::device_vector vals_out(num_items); + c2h::gen(CUB_SEED(2), keys_in, to_bound(0), to_bound(42)); + + auto vals_it = + thrust::make_transform_iterator(thrust::make_counting_iterator(0U), index_to_huge_type_op_t{}); + + // Needs to be device accessible: the count output is handed to the algorithm as a raw pointer into this device_vector + thrust::device_vector num_selected_out(1, 0); + int* d_first_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data()); + + select_unique_by_key( + thrust::raw_pointer_cast(keys_in.data()), + vals_it, + thrust::raw_pointer_cast(keys_out.data()), + thrust::raw_pointer_cast(vals_out.data()), + d_first_num_selected_out, + num_items); + + // Ensure that we create the same output as std::unique applied to the zipped (key, value) pairs + thrust::host_vector reference_keys = keys_in; + thrust::host_vector reference_vals(num_items); + thrust::copy(vals_it, vals_it + num_items, reference_vals.begin()); + + const auto zip_begin = thrust::make_zip_iterator(reference_keys.begin(), reference_vals.begin()); + const auto zip_end = thrust::make_zip_iterator(reference_keys.end(), reference_vals.end()); + const auto boundary = std::unique(zip_begin, zip_end, project_first{}); + REQUIRE((boundary - zip_begin) == num_selected_out[0]); + + keys_out.resize(num_selected_out[0]); + vals_out.resize(num_selected_out[0]); + reference_keys.resize(num_selected_out[0]); + reference_vals.resize(num_selected_out[0]); + REQUIRE(reference_keys == keys_out); + REQUIRE(reference_vals == vals_out); +}