Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues that came up with building cuDF with main #1643

Merged
merged 9 commits into from
May 6, 2024
Merged
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_vsmem.cuh>

#include <thrust/detail/integer_math.h>
#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
Expand Down
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include <cub/util_deprecated.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_vsmem.cuh>

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

Expand Down
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include <cub/util_deprecated.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_vsmem.cuh>

#include <iterator>

Expand Down
196 changes: 0 additions & 196 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,10 @@

#include <cub/detail/device_synchronize.cuh>
#include <cub/util_debug.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
// for backward compatibility
#include <cub/util_temporary_storage.cuh>

#include <cuda/discard_memory>
miscco marked this conversation as resolved.
Show resolved Hide resolved
#include <cuda/std/type_traits>
#include <cuda/std/utility>
miscco marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -70,7 +68,6 @@ CUB_NAMESPACE_BEGIN

namespace detail
{

/**
* @brief Helper class template that allows overwriting the `BLOCK_THREAD` and `ITEMS_PER_THREAD`
* configurations of a given policy.
Expand All @@ -82,199 +79,6 @@ struct policy_wrapper_t : PolicyT
static constexpr int BLOCK_THREADS = BLOCK_THREADS_;
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
};

/**
* @brief Helper struct to wrap all the information needed to implement virtual shared memory that's passed to a kernel.
*
*/
struct vsmem_t
{
void* gmem_ptr;
};

// The maximum amount of static shared memory available per thread block
// Note that in contrast to dynamic shared memory, static shared memory is still limited to 48 KB
static constexpr std::size_t max_smem_per_block = 48 * 1024;

/**
* @brief Class template that helps to prevent exceeding the available shared memory per thread block.
*
* @tparam AgentT The agent for which we check whether per-thread block shared memory is sufficient or whether virtual
* shared memory is needed.
*/
template <typename AgentT>
class vsmem_helper_impl
{
private:
// Per-block virtual shared memory may be padded to make sure vsmem is an integer multiple of `line_size`
static constexpr std::size_t line_size = 128;

// The amount of shared memory or virtual shared memory required by the algorithm's agent
static constexpr std::size_t required_smem = sizeof(typename AgentT::TempStorage);

// Whether we need to allocate global memory-backed virtual shared memory
static constexpr bool needs_vsmem = required_smem > max_smem_per_block;

// Padding bytes to an integer multiple of `line_size`. Only applies to virtual shared memory
static constexpr std::size_t padding_bytes =
(required_smem % line_size == 0) ? 0 : (line_size - (required_smem % line_size));

public:
// Type alias to be used for static temporary storage declaration within the algorithm's kernel
using static_temp_storage_t = cub::detail::conditional_t<needs_vsmem, cub::NullType, typename AgentT::TempStorage>;

// The amount of global memory-backed virtual shared memory needed, padded to an integer multiple of 128 bytes
static constexpr std::size_t vsmem_per_block = needs_vsmem ? (required_smem + padding_bytes) : 0;

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we can use native shared memory as temporary
* storage.
*/
static _CCCL_DEVICE _CCCL_FORCEINLINE typename AgentT::TempStorage&
get_temp_storage(typename AgentT::TempStorage& static_temp_storage, vsmem_t&)
{
return static_temp_storage;
}

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we can use native shared memory as temporary
* storage and taking a linear block id.
*/
static __device__ __forceinline__ typename AgentT::TempStorage&
get_temp_storage(typename AgentT::TempStorage& static_temp_storage, vsmem_t&, std::size_t)
{
return static_temp_storage;
}

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we have to use global memory-backed
* virtual shared memory as temporary storage.
*/
static _CCCL_DEVICE _CCCL_FORCEINLINE typename AgentT::TempStorage&
get_temp_storage(cub::NullType& static_temp_storage, vsmem_t& vsmem)
{
return *reinterpret_cast<typename AgentT::TempStorage*>(
static_cast<char*>(vsmem.gmem_ptr) + (vsmem_per_block * blockIdx.x));
}

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we have to use global memory-backed
* virtual shared memory as temporary storage and taking a linear block id.
*/
static __device__ __forceinline__ typename AgentT::TempStorage&
get_temp_storage(cub::NullType& static_temp_storage, vsmem_t& vsmem, std::size_t linear_block_id)
{
return *reinterpret_cast<typename AgentT::TempStorage*>(
static_cast<char*>(vsmem.gmem_ptr) + (vsmem_per_block * linear_block_id));
}

/**
* @brief Hints to discard modified cache lines of the used virtual shared memory.
* modified cache lines.
*
* @note Needs to be followed by `__syncthreads()` if the function returns true and the virtual shared memory is
* supposed to be reused after this function call.
*/
template <bool needs_vsmem_ = needs_vsmem, typename ::cuda::std::enable_if<!needs_vsmem_, int>::type = 0>
static _CCCL_DEVICE _CCCL_FORCEINLINE bool discard_temp_storage(typename AgentT::TempStorage& temp_storage)
{
return false;
}

/**
* @brief Hints to discard modified cache lines of the used virtual shared memory.
* modified cache lines.
*
* @note Needs to be followed by `__syncthreads()` if the function returns true and the virtual shared memory is
* supposed to be reused after this function call.
*/
template <bool needs_vsmem_ = needs_vsmem, typename ::cuda::std::enable_if<needs_vsmem_, int>::type = 0>
static _CCCL_DEVICE _CCCL_FORCEINLINE bool discard_temp_storage(typename AgentT::TempStorage& temp_storage)
{
// Ensure all threads finished using temporary storage
CTA_SYNC();

const std::size_t linear_tid = threadIdx.x;
const std::size_t block_stride = line_size * blockDim.x;

char* ptr = reinterpret_cast<char*>(&temp_storage);
auto ptr_end = ptr + vsmem_per_block;

// 128 byte-aligned virtual shared memory discard
for (auto thread_ptr = ptr + (linear_tid * line_size); thread_ptr < ptr_end; thread_ptr += block_stride)
{
cuda::discard_memory(thread_ptr, line_size);
}

return true;
}
};

template <class DefaultAgentT, class FallbackAgentT>
constexpr bool use_fallback_agent()
{
return (sizeof(typename DefaultAgentT::TempStorage) > max_smem_per_block)
&& (sizeof(typename FallbackAgentT::TempStorage) <= max_smem_per_block);
}

/**
* @brief Class template that helps to prevent exceeding the available shared memory per thread block with two measures:
* (1) If an agent's `TempStorage` declaration exceeds the maximum amount of shared memory per thread block, we check
* whether using a fallback policy, e.g., with a smaller tile size, would fit into shared memory.
* (2) If the fallback still doesn't fit into shared memory, we make use of virtual shared memory that is backed by
* global memory.
*
* @tparam DefaultAgentPolicyT The default tuning policy that is used if the default agent's shared memory requirements
* fall within the bounds of `max_smem_per_block` or when virtual shared memory is needed
* @tparam DefaultAgentT The default agent, instantiated with the given default tuning policy
* @tparam FallbackAgentPolicyT A fallback tuning policy that may exhibit lower shared memory requirements, e.g., by
* using a smaller tile size, than the default. This fallback policy is used if and only if the shared memory
* requirements of the default agent exceed `max_smem_per_block`, yet the shared memory requirements of the fallback
* agent falls within the bounds of `max_smem_per_block`.
* @tparam FallbackAgentT The fallback agent, instantiated with the given fallback tuning policy
*/
template <typename DefaultAgentPolicyT,
typename DefaultAgentT,
typename FallbackAgentPolicyT = DefaultAgentPolicyT,
typename FallbackAgentT = DefaultAgentT,
bool UseFallbackPolicy = use_fallback_agent<DefaultAgentT, FallbackAgentT>()>
struct vsmem_helper_with_fallback_impl : public vsmem_helper_impl<DefaultAgentT>
{
using agent_t = DefaultAgentT;
using agent_policy_t = DefaultAgentPolicyT;
};
template <typename DefaultAgentPolicyT, typename DefaultAgentT, typename FallbackAgentPolicyT, typename FallbackAgentT>
struct vsmem_helper_with_fallback_impl<DefaultAgentPolicyT, DefaultAgentT, FallbackAgentPolicyT, FallbackAgentT, true>
: public vsmem_helper_impl<FallbackAgentT>
{
using agent_t = FallbackAgentT;
using agent_policy_t = FallbackAgentPolicyT;
};

/**
* @brief Alias template for the `vsmem_helper_with_fallback_impl` that instantiates the given AgentT template with the
* respective policy as first template parameter, followed by the parameters captured by the `AgentParamsT` template
* parameter pack.
*/
template <typename DefaultPolicyT, typename FallbackPolicyT, template <typename...> class AgentT, typename... AgentParamsT>
using vsmem_helper_fallback_policy_t =
vsmem_helper_with_fallback_impl<DefaultPolicyT,
AgentT<DefaultPolicyT, AgentParamsT...>,
FallbackPolicyT,
AgentT<FallbackPolicyT, AgentParamsT...>>;

/**
* @brief Alias template for the `vsmem_helper_t` by using a simple fallback policy that uses `DefaultPolicyT` as basis,
* overwriting `64` threads per block and `1` item per thread.
*/
template <typename DefaultPolicyT, template <typename...> class AgentT, typename... AgentParamsT>
using vsmem_helper_default_fallback_policy_t =
vsmem_helper_fallback_policy_t<DefaultPolicyT, policy_wrapper_t<DefaultPolicyT, 64, 1>, AgentT, AgentParamsT...>;

} // namespace detail

/**
Expand Down
Loading
Loading