Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stablize depthwise conv #35161

Merged
merged 3 commits into from
Sep 1, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions paddle/fluid/operators/math/depthwise_conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,43 @@ namespace operators {
namespace math {

template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
typedef cub::WarpReduce<T> WarpReduce;
typename WarpReduce::TempStorage temp_storage;
val = WarpReduce(temp_storage).Sum(val, warp_size);
return val;
}

#ifdef __HIPCC__
int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
value = WarpReduce(temp_storage).Sum(value, block_size);
#else
value = WarpReduce(temp_storage).Sum(value);
#endif
template <typename T>
__forceinline__ __device__ T BlockReduceSum(T val) {
static __shared__ T shared[32];
int thread_id = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * blockDim.x * blockDim.y;
int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
int lane = thread_id % warp_size;
int wid = thread_id / warp_size;

val = WarpReduceSum(val, warp_size); // Each warp performs partial reduction

if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
__syncthreads(); // Wait for all partial reductions

// read from shared memory only if that warp existed
int block_size = blockDim.x * blockDim.y * blockDim.z;
if (thread_id < (block_size - 1) / warp_size + 1) {
val = shared[lane];
} else {
val = static_cast<T>(0);
}

if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
if (wid == 0) {
val = WarpReduceSum(val, warp_size); // Final reduce within first warp
}
__syncthreads();
if (thread_id != 0) {
val = static_cast<T>(0);
}
return val;
}

#define ARG_DEFINE_KernelDepthwiseConv \
Expand Down Expand Up @@ -665,7 +690,9 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
}
}
}
CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);

T val = BlockReduceSum(s);
platform::CudaAtomicAdd(&filter_grad_data[gbid], val);
}

template <typename T, bool fuse_relu_before_conv>
Expand Down Expand Up @@ -892,6 +919,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
int blocks;
dim3 threads;
dim3 grid;

if (data_layout != DataLayout::kNHWC) {
if (output_width > 1024 && output_width <= 2048)
thread = (output_width - 1) / 2 + 1;
Expand Down Expand Up @@ -1034,6 +1062,7 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
int blocks;
dim3 threads;
dim3 grid;

if (data_layout != DataLayout::kNHWC) {
if (input_width > 1024 && input_width <= 2048) {
thread = (input_width - 1) / 2 + 1;
Expand Down