Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Phi]Interploatd kernels into phi #40855

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5ebe9ec
add interploate cpu kernel
MingMingShangTian Mar 22, 2022
dd97df2
Merge branch 'develop' into interp_kernels
MingMingShangTian Mar 22, 2022
61a1b10
fix nullptr bug
MingMingShangTian Mar 22, 2022
a561378
Merge branch 'interp_kernels' of https://github.com/MingMingShangTian…
MingMingShangTian Mar 22, 2022
0c1d3d1
add interpolate gpu kernel
MingMingShangTian Mar 22, 2022
d6b39e8
fix unit test error
MingMingShangTian Mar 23, 2022
1896bde
remove raw kernels
MingMingShangTian Mar 23, 2022
2328f2a
add cuda kernel impl
MingMingShangTian Mar 24, 2022
2ffd911
add infermeta
MingMingShangTian Mar 24, 2022
4de6ff6
recover accidentally deleted kernels in interpolate op
MingMingShangTian Mar 24, 2022
a5d5baa
fix grad x_grad name error
MingMingShangTian Mar 24, 2022
d49a7cf
remove interpolate_v2_op.h
MingMingShangTian Mar 24, 2022
f5cb832
rm unused codes
MingMingShangTian Mar 24, 2022
0dc7845
fix xpu build error
MingMingShangTian Mar 24, 2022
ed038f8
fix build error
MingMingShangTian Mar 25, 2022
e55ea20
fix namespace error
MingMingShangTian Mar 25, 2022
030df75
Merge branch 'develop' into interp_kernels
MingMingShangTian Mar 28, 2022
84ec6eb
add register header for nup
MingMingShangTian Mar 28, 2022
8a65ebf
Merge branch 'develop' into interp_kernels
MingMingShangTian Mar 29, 2022
bf3bbe6
fix infermeta error
MingMingShangTian Mar 29, 2022
f9ceb09
Merge branch 'develop' into interp_kernels
MingMingShangTian Mar 30, 2022
ec78393
modify by review
MingMingShangTian Mar 30, 2022
550dd7e
add the missing args in test_trt_convert_nearest_interp_v2
MingMingShangTian Mar 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2104,7 +2104,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
typeid(paddle::optional<const phi::DenseTensor&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<const phi::SelectedRows&>)))) {
typeid(paddle::optional<const phi::SelectedRows&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<
const std::vector<const phi::DenseTensor*>>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
Expand Down Expand Up @@ -2366,6 +2370,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
std::type_index(typeid(std::vector<std::string>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<float>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr_it->second));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ void BuildDygraphPhiKernelContext(
auto end_idx = start_idx + 1;
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
continue;
} else if (input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<
const std::vector<const phi::DenseTensor*>>))) {
kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
continue;
} else {
PADDLE_THROW(phi::errors::NotFound(
"Can not find input variable '%s' for %s OP, please check whether "
Expand Down Expand Up @@ -545,6 +553,9 @@ void BuildDygraphPhiKernelContext(
std::type_index(typeid(std::vector<std::string>))) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<float>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
Expand Down
65 changes: 28 additions & 37 deletions paddle/fluid/operators/interpolate_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/interpolate_v2_op.h"
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Expand Down Expand Up @@ -722,64 +726,51 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateV2GradNoNeedBufferVarsInferer,
// not
// compatible with interp_op, so a new one is added in paddle2.0
namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(bilinear_interp_v2, BilinearInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(nearest_interp_v2, NearestInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(trilinear_interp_v2,
TrilinearInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(bicubic_interp_v2, BicubicInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(linear_interp_v2, LinearInterpInferShapeFunctor,
PD_INFER_META(phi::InterpolateInferMeta));

REGISTER_OPERATOR(bilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
BilinearInterpInferShapeFunctor);
REGISTER_OPERATOR(bilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
NearestInterpInferShapeFunctor);
REGISTER_OPERATOR(nearest_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
TrilinearInterpInferShapeFunctor);
REGISTER_OPERATOR(trilinear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
BicubicInterpInferShapeFunctor);
REGISTER_OPERATOR(bicubic_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<int>,
ops::InterpolateV2Kernel<int64_t>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(nearest_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(trilinear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OPERATOR(linear_interp_v2, ops::InterpolateV2Op,
ops::InterpolateV2OpMaker,
ops::InterpolateV2GradMaker<paddle::framework::OpDesc>,
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>);
ops::InterpolateV2GradMaker<paddle::imperative::OpBase>,
LinearInterpInferShapeFunctor);
REGISTER_OPERATOR(linear_interp_v2_grad, ops::InterpolateV2OpGrad,
ops::InterpolateV2GradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>,
ops::InterpolateV2Kernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(linear_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_v2, ops::InterpolateV2Kernel<float>,
ops::InterpolateV2Kernel<double>);
REGISTER_OP_CPU_KERNEL(bicubic_interp_v2_grad,
ops::InterpolateV2GradKernel<float>,
ops::InterpolateV2GradKernel<double>);
Loading