Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Phi]Add diag_v2 grad kernel #40447

Merged
merged 4 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions paddle/fluid/operators/diag_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
Expand Down Expand Up @@ -58,15 +56,56 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
}
};

class DiagV2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "X", "X", "DiagV2Grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "DiagV2Grad");

ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};

template <typename T>
class DiagV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("diag_v2_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};

DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagGradV2NoNeedBufferVarsInferer, "X");

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(diag_v2, DiagInferShapeFunctor,
PD_INFER_META(phi::DiagInferMeta));

REGISTER_OPERATOR(
diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
DiagInferShapeFunctor);
REGISTER_OPERATOR(diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
ops::DiagV2GradOpMaker<paddle::framework::OpDesc>,
ops::DiagV2GradOpMaker<paddle::imperative::OpBase>,
DiagInferShapeFunctor);

REGISTER_OPERATOR(diag_v2_grad, ops::DiagV2GradOp,
ops::DiagGradV2NoNeedBufferVarsInferer);
1 change: 1 addition & 0 deletions paddle/phi/core/compat/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
* They are marked here uniformly.
*/
const std::unordered_set<std::string> deprecated_op_names({"diag",
"diag_grad",
"flatten",
"flatten_grad",
"isinf",
Expand Down
71 changes: 71 additions & 0 deletions paddle/phi/kernels/cpu/diag_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/diag_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void DiagGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
const T* dout_data = out_grad.data<T>();
auto dx_dims = x_grad->dims();
auto dout_dims = out_grad.dims();

if (dx_dims.size() == 1) {
auto dx_length = dx_dims[0];
const int& dx_stride = phi::funcs::ComputeStride(0, dx_dims);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

前后代码风格保持一致,去除const &

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
dout_data +=
(offset >= 0 ? offset * dout_stride_1 : -offset * dout_stride_0);

for (int i = 0; i < dx_length; i++) {
dx_data[i * dx_stride] = dout_data[i * (dout_stride_0 + dout_stride_1)];
}
} else {
phi::funcs::SetConstant<Context, T> set_padding_value;
set_padding_value(dev_ctx, x_grad, static_cast<T>(0));

const int& dx_stride_0 = phi::funcs::ComputeStride(0, dx_dims);
const int& dx_stride_1 = phi::funcs::ComputeStride(1, dx_dims);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
dx_data += (offset >= 0 ? offset * dx_stride_1 : -offset * dx_stride_0);

auto dout_length = dout_dims[0];
for (int i = 0; i < dout_length; i++) {
dx_data[i * (dx_stride_0 + dx_stride_1)] = dout_data[i * dout_stride_0];
}
}
}

} // namespace phi

PD_REGISTER_KERNEL(diag_grad,
CPU,
ALL_LAYOUT,
phi::DiagGradKernel,
int,
int64_t,
float,
double) {}
28 changes: 28 additions & 0 deletions paddle/phi/kernels/diag_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void DiagGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
DenseTensor* x_grad);

} // namespace phi
138 changes: 138 additions & 0 deletions paddle/phi/kernels/gpu/diag_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/diag_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

// Extract the diagonal of a matrix 'dout' to a matrix 'dx'
template <typename T>
__global__ void ExtractDiagonalKernel(const T* dout,
T* dx,
std::ptrdiff_t start,
std::ptrdiff_t dx_length,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t xStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < dx_length;
idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t outOffset = start + sumStride * idx;
dx[xStride * idx] = dout[outOffset];
}
}

// Paste a vector 'dout' to the diagonal of a matrix 'dx'
template <typename T>
__global__ void PasteDiagonalKernel(const T* dout,
T* dx,
std::ptrdiff_t start,
std::ptrdiff_t size,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t outStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
std::ptrdiff_t xOffset = start + sumStride * idx;
dx[xOffset] = dout[outStride * idx];
}
}

template <typename T, typename Context>
void DiagGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
auto* dout_data = out_grad.data<T>();
auto dx_dims = x_grad->dims();
auto dout_dims = out_grad.dims();

auto GetBlockGridSize = [&dev_ctx](int64_t size) {
const int64_t block_size =
std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
return std::tuple<int64_t, int64_t>{block_size, grid_size};
};

if (dx_dims.size() == 1) {
auto dx_length = dx_dims[0];
auto size = (offset > 0) ? dx_length + offset : dx_length - offset;
const int& dx_stride = phi::funcs::ComputeStride(0, dx_dims);
if (size > 0) {
const auto& dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
const auto& dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
auto start =
(offset >= 0 ? offset * dout_stride_1 : -offset * dout_stride_0);

std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
ExtractDiagonalKernel<T><<<std::get<1>(block_grid_size),
std::get<0>(block_grid_size),
0,
dev_ctx.stream()>>>(
dout_data,
dx_data,
start,
dx_length,
dout_stride_0 + dout_stride_1,
dx_stride);
}
} else {
phi::funcs::SetConstant<Context, T> set_padding_value;
set_padding_value(dev_ctx, x_grad, static_cast<T>(0));

const int& dx_stride_0 = phi::funcs::ComputeStride(0, dx_dims);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照之前讨论的,这里const & 可以不适用,风格保持一致,前面的代码没有使用const &

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const int& dx_stride_1 = phi::funcs::ComputeStride(1, dx_dims);
int64_t size;
if (offset > 0) {
size = std::min(dx_dims[0], dx_dims[1] - offset);
} else {
size = std::min(dx_dims[0] + offset, dx_dims[1]);
}

if (size > 0) {
auto start = (offset >= 0 ? offset * dx_stride_1 : -offset * dx_stride_0);
const auto& dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
PasteDiagonalKernel<T><<<std::get<1>(block_grid_size),
std::get<0>(block_grid_size),
0,
dev_ctx.stream()>>>(dout_data,
dx_data,
start,
size,
dx_stride_0 + dx_stride_1,
dout_stride_0);
}
}
}

} // namespace phi

PD_REGISTER_KERNEL(diag_grad,
GPU,
ALL_LAYOUT,
phi::DiagGradKernel,
int,
int64_t,
float,
double) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float16也给注册上去

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

7 changes: 7 additions & 0 deletions paddle/phi/ops/compat/diag_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("diag", {"X"}, {"offset", "padding_value"}, {"Out"});
}

KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")});
}

} // namespace phi

PD_REGISTER_BASE_KERNEL_NAME(diag_v2, diag);
PD_REGISTER_BASE_KERNEL_NAME(diag_v2_grad, diag_grad);

PD_REGISTER_ARG_MAPPING_FN(diag_v2, phi::DiagOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(diag_v2_grad, phi::DiagGradOpArgumentMapping);
8 changes: 6 additions & 2 deletions python/paddle/fluid/tests/unittests/test_diag_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def test_check_output(self):
paddle.enable_static()
self.check_output(check_eager=True)

def test_check_grad(self):
paddle.enable_static()
self.check_grad(['X'], 'Out', check_eager=True)

def init_config(self):
pass

Expand All @@ -62,14 +66,14 @@ def init_config(self):

class TestDiagV2OpCase3(TestDiagV2Op):
def init_config(self):
self.x = np.random.randint(-10, 10, size=(10, 10))
self.x = np.random.randint(-10, 10, size=(10, 10)).astype("float64")
self.out = np.diag(self.x, self.offset)


class TestDiagV2OpCase4(TestDiagV2Op):
def init_config(self):
self.x = np.random.rand(100)
self.padding_value = 8
self.padding_value = 2
n = self.x.size
self.out = self.padding_value * np.ones((n, n)) + np.diag(
self.x, self.offset) - np.diag(self.padding_value * np.ones(n))
Expand Down