diff --git a/paddle/fluid/operators/diag_v2_op.cc b/paddle/fluid/operators/diag_v2_op.cc
index 93fbff67e220b..ac8c12bcd7eba 100644
--- a/paddle/fluid/operators/diag_v2_op.cc
+++ b/paddle/fluid/operators/diag_v2_op.cc
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <algorithm>
-
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/infermeta/unary.h"
@@ -58,15 +56,64 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+// Gradient operator of diag_v2: takes X and Out@GRAD and produces X@GRAD,
+// which has the same dims as X.
+class DiagV2GradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // X@GRAD takes its dims from the forward input X.
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DiagV2Grad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   framework::GradVarName("X"), "DiagV2Grad");
+
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+
+ protected:
+  // The kernel data type follows Out@GRAD.
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+// Describes diag_v2_grad in terms of the forward op: passes X, Out@GRAD and
+// the forward attributes through, and names X@GRAD as the output.
+template <typename T>
+class DiagV2GradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("diag_v2_grad");
+    grad_op->SetInput("X", this->Input("X"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
+// The buffer of X is not needed by diag_v2_grad; the grad kernels never read
+// its data.
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagGradV2NoNeedBufferVarsInferer, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
 DECLARE_INFER_SHAPE_FUNCTOR(diag_v2, DiagInferShapeFunctor,
                             PD_INFER_META(phi::DiagInferMeta));
 
-REGISTER_OPERATOR(
-    diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    DiagInferShapeFunctor);
+REGISTER_OPERATOR(diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
+                  ops::DiagV2GradOpMaker<paddle::framework::OpDesc>,
+                  ops::DiagV2GradOpMaker<paddle::imperative::OpBase>,
+                  DiagInferShapeFunctor);
+
+REGISTER_OPERATOR(diag_v2_grad, ops::DiagV2GradOp,
+                  ops::DiagGradV2NoNeedBufferVarsInferer);
diff --git a/paddle/phi/kernels/cpu/diag_grad_kernel.cc b/paddle/phi/kernels/cpu/diag_grad_kernel.cc
new file mode 100644
index 0000000000000..c56b225e2a753
--- /dev/null
+++ b/paddle/phi/kernels/cpu/diag_grad_kernel.cc
@@ -0,0 +1,81 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/diag_grad_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+// CPU backward kernel of diag.
+//
+// When x_grad is 1-D it is read from the `offset`-th diagonal of out_grad.
+// Otherwise x_grad is zero-filled and out_grad is written onto its
+// `offset`-th diagonal. The data of x is never read.
+template <typename T, typename Context>
+void DiagGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    int offset,
+                    DenseTensor* x_grad) {
+  T* dx_data = dev_ctx.template Alloc<T>(x_grad);
+  const T* dout_data = out_grad.data<T>();
+  auto dx_dims = x_grad->dims();
+  auto dout_dims = out_grad.dims();
+
+  if (dx_dims.size() == 1) {
+    auto dx_length = dx_dims[0];
+    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+
+    auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
+    auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
+    // Move to the first element of the diagonal: `offset` columns to the
+    // right when offset >= 0, `-offset` rows down otherwise.
+    dout_data +=
+        (offset >= 0 ? offset * dout_stride_1 : -offset * dout_stride_0);
+
+    // One step along the diagonal advances one row and one column.
+    for (int64_t i = 0; i < dx_length; i++) {
+      dx_data[i * dx_stride] = dout_data[i * (dout_stride_0 + dout_stride_1)];
+    }
+  } else {
+    // Entries off the selected diagonal receive no gradient.
+    phi::funcs::SetConstant<Context, T> set_padding_value;
+    set_padding_value(dev_ctx, x_grad, static_cast<T>(0));
+
+    int dx_stride_0 = phi::funcs::ComputeStride(0, dx_dims);
+    int dx_stride_1 = phi::funcs::ComputeStride(1, dx_dims);
+    auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
+    dx_data += (offset >= 0 ? offset * dx_stride_1 : -offset * dx_stride_0);
+
+    auto dout_length = dout_dims[0];
+    for (int64_t i = 0; i < dout_length; i++) {
+      dx_data[i * (dx_stride_0 + dx_stride_1)] = dout_data[i * dout_stride_0];
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(diag_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DiagGradKernel,
+                   phi::dtype::float16,
+                   int,
+                   int64_t,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/cpu/diag_kernel.cc b/paddle/phi/kernels/cpu/diag_kernel.cc
index d1e0b8e31e78f..4b060f0372a5b 100644
--- a/paddle/phi/kernels/cpu/diag_kernel.cc
+++ b/paddle/phi/kernels/cpu/diag_kernel.cc
@@ -62,5 +62,12 @@ void DiagKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    diag, CPU, ALL_LAYOUT, phi::DiagKernel, int, float, double, int64_t) {}
+PD_REGISTER_KERNEL(diag,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DiagKernel,
+                   phi::dtype::float16,
+                   int,
+                   float,
+                   double,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/diag_grad_kernel.h b/paddle/phi/kernels/diag_grad_kernel.h
new file mode 100644
index 0000000000000..b9edab9bec44c
--- /dev/null
+++ b/paddle/phi/kernels/diag_grad_kernel.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+/**
+ * @brief Backward kernel of diag.
+ *
+ * @param dev_ctx   Device context; used to allocate x_grad.
+ * @param x         Forward input of diag.
+ * @param out_grad  Gradient with respect to the forward output.
+ * @param offset    Diagonal index: 0 is the main diagonal, positive values are
+ *                  above it and negative values below it.
+ * @param x_grad    Output: gradient with respect to x.
+ */
+template <typename T, typename Context>
+void DiagGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    int offset,
+                    DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/diag_grad_kernel.cu b/paddle/phi/kernels/gpu/diag_grad_kernel.cu
new file mode 100644
index 0000000000000..65bf837e6cf8a
--- /dev/null
+++ b/paddle/phi/kernels/gpu/diag_grad_kernel.cu
@@ -0,0 +1,157 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/diag_grad_kernel.h"
+
+#include <algorithm>
+#include <tuple>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+// Reads a diagonal of the matrix 'dout' into the vector 'dx':
+//   dx[xStride * i] = dout[start + sumStride * i], for i in [0, dx_length).
+// Each thread strides over i by the total number of launched threads.
+template <typename T>
+__global__ void ExtractDiagonalKernel(const T* dout,
+                                      T* dx,
+                                      std::ptrdiff_t start,
+                                      std::ptrdiff_t dx_length,
+                                      const std::ptrdiff_t sumStride,
+                                      const std::ptrdiff_t xStride) {
+  for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+       idx < dx_length;
+       idx += gridDim.x * blockDim.x) {
+    const std::ptrdiff_t outOffset = start + sumStride * idx;
+    dx[xStride * idx] = dout[outOffset];
+  }
+}
+
+// Writes the vector 'dout' onto a diagonal of the matrix 'dx':
+//   dx[start + sumStride * i] = dout[outStride * i], for i in [0, size).
+// Each thread strides over i by the total number of launched threads.
+template <typename T>
+__global__ void PasteDiagonalKernel(const T* dout,
+                                    T* dx,
+                                    std::ptrdiff_t start,
+                                    std::ptrdiff_t size,
+                                    const std::ptrdiff_t sumStride,
+                                    const std::ptrdiff_t outStride) {
+  for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
+       idx += gridDim.x * blockDim.x) {
+    std::ptrdiff_t xOffset = start + sumStride * idx;
+    dx[xOffset] = dout[outStride * idx];
+  }
+}
+
+// GPU backward kernel of diag.
+//
+// When x_grad is 1-D it is read from the `offset`-th diagonal of out_grad.
+// Otherwise x_grad is zero-filled and out_grad is written onto its
+// `offset`-th diagonal. The data of x is never read.
+template <typename T, typename Context>
+void DiagGradKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& out_grad,
+                    int offset,
+                    DenseTensor* x_grad) {
+  T* dx_data = dev_ctx.template Alloc<T>(x_grad);
+  auto* dout_data = out_grad.data<T>();
+  auto dx_dims = x_grad->dims();
+  auto dout_dims = out_grad.dims();
+
+  // Picks the launch configuration for `size` (> 0) elements and returns
+  // {threads per block, number of blocks}.
+  auto GetBlockGridSize = [&dev_ctx](int64_t size) {
+    const int64_t block_size =
+        std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
+    int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    const int64_t max_blocks =
+        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
+    const int64_t grid_size =
+        std::min(max_blocks, (size + block_size - 1) / block_size);
+    return std::tuple<int64_t, int64_t>{block_size, grid_size};
+  };
+
+  if (dx_dims.size() == 1) {
+    auto dx_length = dx_dims[0];
+    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+    // The kernel writes exactly dx_length elements, so size the launch by
+    // dx_length.
+    if (dx_length > 0) {
+      auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
+      auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
+      auto start =
+          (offset >= 0 ? offset * dout_stride_1 : -offset * dout_stride_0);
+
+      std::tuple<int64_t, int64_t> block_grid_size =
+          GetBlockGridSize(dx_length);
+      ExtractDiagonalKernel<T><<<std::get<1>(block_grid_size),
+                                 std::get<0>(block_grid_size),
+                                 0,
+                                 dev_ctx.stream()>>>(
+          dout_data,
+          dx_data,
+          start,
+          dx_length,
+          dout_stride_0 + dout_stride_1,
+          dx_stride);
+    }
+  } else {
+    // Entries off the selected diagonal receive no gradient.
+    phi::funcs::SetConstant<Context, T> set_padding_value;
+    set_padding_value(dev_ctx, x_grad, static_cast<T>(0));
+
+    int dx_stride_0 = phi::funcs::ComputeStride(0, dx_dims);
+    int dx_stride_1 = phi::funcs::ComputeStride(1, dx_dims);
+    // Number of elements on the `offset`-th diagonal of x_grad.
+    int64_t size;
+    if (offset > 0) {
+      size = std::min(dx_dims[0], dx_dims[1] - offset);
+    } else {
+      size = std::min(dx_dims[0] + offset, dx_dims[1]);
+    }
+
+    if (size > 0) {
+      auto start = (offset >= 0 ? offset * dx_stride_1 : -offset * dx_stride_0);
+      auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
+      std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
+      PasteDiagonalKernel<T><<<std::get<1>(block_grid_size),
+                               std::get<0>(block_grid_size),
+                               0,
+                               dev_ctx.stream()>>>(dout_data,
+                                                   dx_data,
+                                                   start,
+                                                   size,
+                                                   dx_stride_0 + dx_stride_1,
+                                                   dout_stride_0);
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(diag_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DiagGradKernel,
+                   phi::dtype::float16,
+                   int,
+                   int64_t,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/diag_kernel.cu b/paddle/phi/kernels/gpu/diag_kernel.cu
index fc70639787173..95d3d3365d91b 100644
--- a/paddle/phi/kernels/gpu/diag_kernel.cu
+++ b/paddle/phi/kernels/gpu/diag_kernel.cu
@@ -130,5 +130,12 @@ void DiagKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    diag, GPU, ALL_LAYOUT, phi::DiagKernel, int, int64_t, float, double) {}
+PD_REGISTER_KERNEL(diag,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DiagKernel,
+                   phi::dtype::float16,
+                   int,
+                   int64_t,
+                   float,
+                   double) {}
diff --git a/paddle/phi/ops/compat/diag_sig.cc b/paddle/phi/ops/compat/diag_sig.cc
index 0a14b9095c834..f3245b922c0d9 100644
--- a/paddle/phi/ops/compat/diag_sig.cc
+++ b/paddle/phi/ops/compat/diag_sig.cc
@@ -20,8 +20,17 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("diag", {"X"}, {"offset", "padding_value"}, {"Out"});
 }
 
+// Maps diag_v2_grad's inputs (X, Out@GRAD), its `offset` attribute and its
+// output X@GRAD onto the phi "diag_grad" kernel signature.
+KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")});
+}
+
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(diag_v2, diag);
+PD_REGISTER_BASE_KERNEL_NAME(diag_v2_grad, diag_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(diag_v2, phi::DiagOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(diag_v2_grad, phi::DiagGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_diag_v2.py b/python/paddle/fluid/tests/unittests/test_diag_v2.py
index 0371fa054282b..9f727608f816c 100644
--- a/python/paddle/fluid/tests/unittests/test_diag_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_diag_v2.py
@@ -44,6 +44,10 @@ def test_check_output(self):
         paddle.enable_static()
         self.check_output(check_eager=True)
 
+    def test_check_grad(self):
+        paddle.enable_static()
+        self.check_grad(['X'], 'Out', check_eager=True)
+
     def init_config(self):
         pass
 
@@ -62,14 +66,14 @@ def init_config(self):
 
 
 class TestDiagV2OpCase3(TestDiagV2Op):
     def init_config(self):
-        self.x = np.random.randint(-10, 10, size=(10, 10))
+        self.x = np.random.randint(-10, 10, size=(10, 10)).astype("float64")
         self.out = np.diag(self.x, self.offset)
 
 
 class TestDiagV2OpCase4(TestDiagV2Op):
     def init_config(self):
         self.x = np.random.rand(100)
-        self.padding_value = 8
+        self.padding_value = 2
         n = self.x.size
         self.out = self.padding_value * np.ones((n, n)) + np.diag(
             self.x, self.offset) - np.diag(self.padding_value * np.ones(n))