Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Apr 18, 2024
1 parent 5415daa commit bbb4c49
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include <cub/util_temporary_storage.cuh>

#include <cuda/std/type_traits>
#include <cuda/std/utility>

#include <array>
#include <atomic>
Expand Down
15 changes: 8 additions & 7 deletions cub/cub/util_vsmem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
# pragma system_header
#endif // no system header

#if defined(_CCCL_CUDA_COMPILER)
#include <cub/util_ptx.cuh>
#endif // _CCCL_CUDA_COMPILER
#include <cub/util_type.cuh>

#include <cuda/discard_memory>
Expand Down Expand Up @@ -131,7 +134,7 @@ public:
* passed to the agent, specialized for the case when we can use native shared memory as temporary
* storage and taking a linear block id.
*/
static _CCCL_DEVICE __forceinline__ typename AgentT::TempStorage&
static _CCCL_DEVICE _CCCL_FORCEINLINE typename AgentT::TempStorage&
get_temp_storage(typename AgentT::TempStorage& static_temp_storage, vsmem_t&, std::size_t)
{
return static_temp_storage;
Expand All @@ -155,7 +158,7 @@ public:
* passed to the agent, specialized for the case when we have to use global memory-backed
* virtual shared memory as temporary storage and taking a linear block id.
*/
static _CCCL_DEVICE __forceinline__ typename AgentT::TempStorage&
static _CCCL_DEVICE _CCCL_FORCEINLINE typename AgentT::TempStorage&
get_temp_storage(cub::NullType& static_temp_storage, vsmem_t& vsmem, std::size_t linear_block_id)
{
return *reinterpret_cast<typename AgentT::TempStorage*>(
Expand All @@ -176,6 +179,7 @@ public:
return false;
}

# if defined(_CCCL_CUDA_COMPILER)
/**
* @brief Hints to discard modified cache lines of the used virtual shared memory.
* modified cache lines.
Expand All @@ -186,9 +190,8 @@ public:
template <bool needs_vsmem_ = needs_vsmem, typename ::cuda::std::enable_if<needs_vsmem_, int>::type = 0>
static _CCCL_DEVICE _CCCL_FORCEINLINE bool discard_temp_storage(typename AgentT::TempStorage& temp_storage)
{
# if defined(_CCCL_CUDA_COMPILER)
// Ensure all threads finished using temporary storage
NV_IF_TARGET(NV_IS_HOST, (__syncthreads();))
CTA_SYNC();

const std::size_t linear_tid = threadIdx.x;
const std::size_t block_stride = line_size * blockDim.x;
Expand All @@ -202,10 +205,8 @@ public:
cuda::discard_memory(thread_ptr, line_size);
}
return true;
# else // ^^^ _CCCL_CUDA_COMPILER ^^^ / vvv !_CCCL_CUDA_COMPILER vvv
return false;
# endif // !_CCCL_CUDA_COMPILER
}
# endif // !_CCCL_CUDA_COMPILER
};

template <class DefaultAgentT, class FallbackAgentT>
Expand Down
5 changes: 2 additions & 3 deletions thrust/thrust/system/cuda/detail/copy_if.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
# include <cub/util_temporary_storage.cuh>
# include <cub/util_type.cuh>

# include <thrust/advance.h>
# include <thrust/detail/alignment.h>
# include <thrust/detail/cstdint.h>
# include <thrust/detail/function.h>
Expand Down Expand Up @@ -146,7 +147,6 @@ struct DispatchCopyIf
// Return for empty problems
if (num_items == 0)
{
output;
return status;
}

Expand Down Expand Up @@ -179,8 +179,7 @@ struct DispatchCopyIf
status = cuda_cub::synchronize(policy);
CUDA_CUB_RET_IF_FAIL(status);
OffsetT num_selected = get_value(policy, d_num_selected_out);

output += num_selected;
thrust::advance(output, num_selected);
return status;
}
};
Expand Down

0 comments on commit bbb4c49

Please sign in to comment.