-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
supplement the function of slice. #34172
Merged
Merged
Changes from 11 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
f41ce1c
supplement the function of slice
hbwx24 3967d6b
edit unittest
hbwx24 661eb52
deal with conflict
hbwx24 0a476ea
strided_slice_op support .
hbwx24 ae19cc0
Merge remote-tracking branch 'upstream/develop' into slice/static_get…
hbwx24 0d14ffc
polish error message.
hbwx24 6e3c267
polish error message.
hbwx24 9812033
polish code.
hbwx24 8b94458
polish unittest.
hbwx24 39aa575
polish code.
hbwx24 3b6ad37
polish code
hbwx24 775b464
polish error message.
hbwx24 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,13 @@ class StridedSliceOp : public framework::OperatorWithKernel { | |
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice"); | ||
|
||
auto input_var_type = ctx->GetInputsVarType("Input")[0]; | ||
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { | ||
if (ctx->IsRuntime()) { | ||
// shape is determined by Runtime. | ||
return; | ||
} | ||
} | ||
auto in_dims = ctx->GetInputDim("Input"); | ||
PADDLE_ENFORCE_LT( | ||
in_dims.size(), 7, | ||
|
@@ -154,6 +160,26 @@ class StridedSliceOp : public framework::OperatorWithKernel { | |
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
auto *in_var = ctx.InputVar("Input"); | ||
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>(); | ||
if (is_in_var_array) { | ||
auto &tensor_array = in_var->Get<framework::LoDTensorArray>(); | ||
for (auto &tensor : tensor_array) { | ||
if (!platform::is_cuda_pinned_place(tensor.place())) { | ||
PADDLE_ENFORCE_EQ( | ||
platform::is_same_place(tensor.place(), | ||
ctx.device_context().GetPlace()), | ||
true, platform::errors::InvalidArgument( | ||
"Place of context is %s. Place of context is %s. They " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有一个place是tensor的? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, thx. |
||
"are should be same, but reveived different place.", | ||
string::to_string(ctx.device_context().GetPlace()), | ||
string::to_string(tensor.place()))); | ||
} | ||
} | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), | ||
ctx.device_context()); | ||
} | ||
// NOTE: cuda pinned tensor need to copy its data to target place | ||
auto in_tensor = ctx.Input<Tensor>("Input"); | ||
if (platform::is_cuda_pinned_place(in_tensor->place())) { | ||
|
@@ -179,6 +205,14 @@ class StridedSliceOp : public framework::OperatorWithKernel { | |
} | ||
}; | ||
|
||
class StridedSliceOpVarTypeInference : public framework::VarTypeInference { | ||
public: | ||
void operator()(framework::InferVarTypeContext *ctx) const override { | ||
ctx->SetOutputType("Out", ctx->GetInputType("Input")); | ||
ctx->SetOutputDataType("Out", ctx->GetInputDataType("Input")); | ||
} | ||
}; | ||
|
||
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
|
@@ -259,6 +293,13 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { | |
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", | ||
"Out@GRAD", "StridedSliceGrad"); | ||
|
||
auto input_var_type = ctx->GetInputsVarType("Input")[0]; | ||
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { | ||
if (ctx->IsRuntime()) { | ||
// shape is determined by Runtime | ||
return; | ||
} | ||
} | ||
auto x_dims = ctx->GetInputDim("Input"); | ||
auto x_grad_name = framework::GradVarName("Input"); | ||
if (ctx->HasOutput(x_grad_name)) { | ||
|
@@ -308,6 +349,16 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpMaker<T> { | |
bind->SetType("strided_slice_grad"); | ||
} | ||
}; | ||
class StridedSliceGradOpVarTypeInference : public framework::VarTypeInference { | ||
public: | ||
void operator()(framework::InferVarTypeContext *ctx) const override { | ||
ctx->SetOutputType(framework::GradVarName("Input"), | ||
ctx->GetInputType(framework::GradVarName("Out"))); | ||
ctx->SetOutputDataType( | ||
framework::GradVarName("Input"), | ||
ctx->GetInputDataType(framework::GradVarName("Out"))); | ||
} | ||
}; | ||
|
||
DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer, | ||
"Input"); | ||
|
@@ -318,9 +369,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer, | |
namespace ops = paddle::operators; | ||
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, | ||
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>, | ||
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>); | ||
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>, | ||
ops::StridedSliceOpVarTypeInference); | ||
|
||
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad, | ||
ops::StridedSliceOpGradNoNeedBufferVarsInferer); | ||
ops::StridedSliceOpGradNoNeedBufferVarsInferer, | ||
ops::StridedSliceGradOpVarTypeInference); | ||
|
||
REGISTER_OP_CPU_KERNEL( | ||
strided_slice, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
结合下面的code,会不会有这种情况,lodtensorarray里面tensor的place是cuda_pinned
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx.