Skip to content

Commit

Permalink
fix bug for fp32 batchnorm_op when using nhwc data_layout (#37020) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Dec 29, 2021
1 parent c111340 commit 8ef7102
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/batch_norm_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
Tensor transformed_d_y(d_y->type());
Tensor transformed_d_x;
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW) {
compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
VLOG(3) << "Transform input tensor from NHWC to NCHW.";
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
&transformed_x);
Expand Down

0 comments on commit 8ef7102

Please sign in to comment.