
[Need review] Added conv + hard_sigmoid oneDNN fuse pass (PaddlePaddle#36869)

* Added conv + hard_sigmoid fuse pass

* Removed IsOptional() statements

* Reverted the removal of IsOptional()
jakpiase authored and piotrekobi committed Nov 3, 2021
1 parent 7b9207f commit f85364e
Showing 7 changed files with 80 additions and 4 deletions.
@@ -88,6 +88,13 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
    desc->SetAttr("fuse_beta",
                  activation->Op()->GetAttrIfExists<float>("beta"));

    if (activation_type() == "hard_sigmoid") {
      desc->SetAttr("fuse_alpha",
                    activation->Op()->GetAttrIfExists<float>("slope"));
      desc->SetAttr("fuse_beta",
                    activation->Op()->GetAttrIfExists<float>("offset"));
    }

    GraphSafeRemoveNodes(graph, {activation, conv_out});

    PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
@@ -213,6 +220,26 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
      .End();
}

Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
  AddOpCompat(OpCompat("hard_sigmoid"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      // optional, default=0.2
      .AddAttr("slope")
      .IsOptional()
      .IsType<float>()
      .End()
      // optional, default=0.5
      .AddAttr("offset")
      .IsOptional()
      .IsType<float>()
      .End();
}

} // namespace ir
} // namespace framework
} // namespace paddle
@@ -259,3 +286,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("hard_swish", 0));

REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DHardSigmoidFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("hard_sigmoid", 0));
@@ -72,6 +72,15 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass {
  Conv2DHardSwishFusePass();
  std::string activation_type() const { return "hard_swish"; }
};
/*
 * Fuse Conv and HardSigmoid class
 */
class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
 public:
  Conv2DHardSigmoidFusePass();
  std::string activation_type() const { return "hard_sigmoid"; }
};

} // namespace ir
} // namespace framework
} // namespace paddle
@@ -148,6 +148,9 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
  MainTest("hard_swish");
}
TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
  MainTest("hard_sigmoid");
}

} // namespace ir
} // namespace framework
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
          "conv_relu6_mkldnn_fuse_pass",                //
          "conv_swish_mkldnn_fuse_pass",                //
          "conv_hard_swish_mkldnn_fuse_pass",           //
          "conv_hard_sigmoid_mkldnn_fuse_pass",         //
          "scale_matmul_fuse_pass",                     //
          "reshape_transpose_matmul_mkldnn_fuse_pass",  //
          "matmul_transpose_reshape_fuse_pass",         //
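
Usage note (a hedged sketch, not part of this change; the function name and model path below are placeholders): because the new pass is appended to CpuPassStrategy's oneDNN list above, enabling MKLDNN on an AnalysisConfig should be enough for conv_hard_sigmoid_mkldnn_fuse_pass to run during graph optimization.

#include "paddle/fluid/inference/api/paddle_inference_api.h"

paddle::AnalysisConfig MakeOneDNNConfig() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/inference_model");  // placeholder model directory
  config.EnableMKLDNN();  // activates the CPU MKLDNN pass list, including the new fuse pass
  return config;
}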
17 changes: 17 additions & 0 deletions paddle/fluid/operators/compat/hard_sigmoid.pbtxt
@@ -0,0 +1,17 @@
type: "hard_sigmoid"
def {
  inputs {
    name: "X"
  }
  outputs {
    name: "Out"
  }
  attrs {
    name: "slope"
    type: FLOAT
  }
  attrs {
    name: "offset"
    type: FLOAT
  }
}
10 changes: 6 additions & 4 deletions paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -475,23 +475,25 @@ class ConvMKLDNNHandlerT
    }
    // Fusion with ReLU layer is executed through the PostOps feature. Create a
    // PostOps object and configure it to execute an eltwise relu operation.
    constexpr float scale = 1.0f;
    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      post_operations.append_eltwise(scale,
                                     mkldnn::algorithm::eltwise_bounded_relu,
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "swish") {
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "hard_swish") {
      post_operations.append_eltwise(
          scale, mkldnn::algorithm::eltwise_hardswish, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "hard_sigmoid") {
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_linear,
                                     fuse_alpha, fuse_beta);
      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_clip,
                                     0.0f, 1.0f);
    }
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
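
Besides adding the hard_sigmoid branch, this hunk hoists the per-branch constexpr float scale = 1.0f; declarations into the single one above the if/else chain. The fusion expresses hard_sigmoid as two chained post-ops: eltwise_linear with alpha = fuse_alpha (slope) and beta = fuse_beta (offset), followed by eltwise_clip to [0, 1]. A scalar sketch of how the two compose (illustrative only, not oneDNN API calls):

#include <algorithm>

// eltwise_linear(alpha, beta): y = alpha * x + beta
// eltwise_clip(0, 1):          y = min(max(y, 0), 1)
float FusedHardSigmoid(float conv_out, float fuse_alpha, float fuse_beta) {
  const float linear = fuse_alpha * conv_out + fuse_beta;
  return std::min(std::max(linear, 0.0f), 1.0f);  // == hard_sigmoid(conv_out)
}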
@@ -102,5 +102,14 @@ def set_params(self):
        self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'


class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
    def set_params(self):
        self.conv_num_filters = 5
        self.conv_filter_size = 5
        self.conv_bias_attr = True
        self.act = "hard_sigmoid"
        self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass'


if __name__ == "__main__":
    unittest.main()
