diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index d3adccff73337..ae57b68ad57ee 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -46,7 +46,7 @@ USE_OP_ITSELF(elementwise_add_grad); USE_OP_ITSELF(matmul_grad); USE_OP_ITSELF(square); USE_OP_ITSELF(transpose2_grad); -USE_OP(concat_grad); +USE_OP_ITSELF(concat_grad); USE_OP_ITSELF(elementwise_mul_grad); USE_OP_ITSELF(sigmoid_grad); USE_OP_ITSELF(tanh_grad); @@ -67,6 +67,7 @@ PD_DECLARE_KERNEL(transpose, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(reshape, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(split, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT); diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 059fafa3e7f4d..a467f2dbee7c9 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -216,18 +216,3 @@ REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad, ops::ConcatDoubleGradOpMaker, ops::ConcatDoubleGradOpMaker, ops::ConcatOpGradNoNeedBufferVarInferer); - -REGISTER_OP_CPU_KERNEL( - concat_grad, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel>, - ops::ConcatGradKernel>); diff --git a/paddle/fluid/operators/concat_op.cu.cc b/paddle/fluid/operators/concat_op.cu.cc deleted file mode 100644 index f7b64f16e2d8b..0000000000000 --- a/paddle/fluid/operators/concat_op.cu.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/concat_op.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/float16.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - concat_grad, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel, - ops::ConcatGradKernel>, - ops::ConcatGradKernel>); diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index ec43e2ad374db..50aca54c12dec 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -39,62 +39,6 @@ static inline int64_t ComputeAxis(int64_t axis, int64_t rank) { } return axis > 0 ? axis : 0; } -template -class ConcatGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - auto ins = ctx.MultiInput("X"); - auto out_var_names = ctx.OutputNames(framework::GradVarName("X")); - auto outs = - ctx.MultiOutput(framework::GradVarName("X")); - - { - auto dx = outs; - auto x = ins; - for (size_t i = 0; i < dx.size(); ++i) { - if (dx[i] != nullptr) { - dx[i]->set_lod(x[i]->lod()); - } - } - } - PADDLE_ENFORCE_NOT_NULL(ins[0], - platform::errors::NotFound( - "The first input tensor is not initalized.")); - - auto axis = ctx.Attr("axis"); - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - axis = GetDataFromTensor(axis_tensor)[0]; - } - axis = ComputeAxis(static_cast(axis), - static_cast(ins[0]->dims().size())); - // get output tensor that the name is not kEmptyVarName - std::vector outputs; - for (size_t j = 0; j < outs.size(); ++j) { - if (out_var_names[j] != framework::kEmptyVarName && - outs[j]->numel() != 0UL) { - outs[j]->mutable_data(ctx.GetPlace()); - outputs.push_back(outs[j]); - } else { - outputs.push_back(nullptr); - } - } - auto& dev_ctx = ctx.template device_context(); - - // Sometimes direct copies will be faster, this maybe need deeply analysis. - if (axis == 0 && outs.size() < 10) { - std::vector ref_shape; - ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end()); - StridedMemcpyWithAxis0(dev_ctx, *out_grad, ref_shape, &outputs); - } else { - math::SplitFunctor split_functor; - split_functor(dev_ctx, *out_grad, ctx.MultiInput("X"), - static_cast(axis), &outputs); - } - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/concat_grad_kernel.h b/paddle/phi/kernels/concat_grad_kernel.h new file mode 100644 index 0000000000000..e407d73bb49ee --- /dev/null +++ b/paddle/phi/kernels/concat_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/kernels/empty_kernel.h" +namespace phi { + +template +void ConcatGradKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& out_grad, + const Scalar& axis_scalar, + std::vector x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/concat_grad_kernel.cc b/paddle/phi/kernels/cpu/concat_grad_kernel.cc new file mode 100644 index 0000000000000..56ed95769fef4 --- /dev/null +++ b/paddle/phi/kernels/cpu/concat_grad_kernel.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/concat_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(concat_grad, + CPU, + ALL_LAYOUT, + phi::ConcatGradKernel, + double, + float, + bool, + int64_t, + int, + uint8_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/concat_grad_kernel.cu b/paddle/phi/kernels/gpu/concat_grad_kernel.cu new file mode 100644 index 0000000000000..2445978daca46 --- /dev/null +++ b/paddle/phi/kernels/gpu/concat_grad_kernel.cu @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/concat_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(concat_grad, + GPU, + ALL_LAYOUT, + phi::ConcatGradKernel, + float, + double, + bool, + int64_t, + int, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/concat_grad_kernel_impl.h b/paddle/phi/kernels/impl/concat_grad_kernel_impl.h new file mode 100644 index 0000000000000..e89920340ff18 --- /dev/null +++ b/paddle/phi/kernels/impl/concat_grad_kernel_impl.h @@ -0,0 +1,69 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" + +namespace phi { + +template +void ConcatGradKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& out_grad, + const Scalar& axis_scalar, + std::vector x_grad) { + auto outs = x_grad; + { + auto dx = x_grad; + for (size_t i = 0; i < dx.size(); ++i) { + if (dx[i] != nullptr) { + dx[i]->set_lod(x[i]->lod()); + } + } + } + PADDLE_ENFORCE_NOT_NULL( + x[0], phi::errors::NotFound("The first input tensor is not initalized.")); + + auto axis = axis_scalar.to(); + axis = funcs::ComputeAxis(static_cast(axis), + static_cast(x[0]->dims().size())); + // get output tensor that the name is not kEmptyVarName + std::vector outputs; + for (size_t j = 0; j < outs.size(); ++j) { + if (outs[j] && outs[j]->numel() != 0UL) { + dev_ctx.template Alloc(outs[j]); + + outputs.push_back(outs[j]); + } else { + outputs.push_back(nullptr); + } + } + + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + std::vector ref_shape; + ref_shape.insert(ref_shape.begin(), x.begin(), x.end()); + paddle::operators::StridedMemcpyWithAxis0( + dev_ctx, out_grad, ref_shape, &outputs); + } else { + phi::funcs::SplitFunctor split_functor; + split_functor(dev_ctx, out_grad, x, static_cast(axis), &outputs); + } +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/concat_sig.cc b/paddle/phi/ops/compat/concat_sig.cc index 21e653ccfe90f..d443f521c6146 100644 --- a/paddle/phi/ops/compat/concat_sig.cc +++ b/paddle/phi/ops/compat/concat_sig.cc @@ -23,6 +23,20 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature("concat", {"X"}, {"axis"}, {"Out"}); } +KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("concat_grad", + {"X", {GradVarName("Out")}}, + {"AxisTensor"}, + {{GradVarName("X")}}); + } + return KernelSignature("concat_grad", + {"X", {GradVarName("Out")}}, + {"axis"}, + {{GradVarName("X")}}); +} + } // namespace phi PD_REGISTER_ARG_MAPPING_FN(concat, phi::ConcatOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(concat_grad, phi::ConcatGradOpArgumentMapping);