Lw trt #5

Merged (6 commits, Aug 5, 2024)
57 changes: 55 additions & 2 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -210,6 +210,7 @@ class DepthwiseConv2dTransposeOpPattern
        return false;
      }
    }

    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
@@ -795,7 +796,7 @@ class SplitWithNumOpPattern
VLOG(3) << "The (" << axis << ") dim of input should not be -1";
return false;
}

if (!op->HasAttribute("num") ) {
VLOG(3)<< "split_with_num op must has num attributes";
return false;
@@ -826,6 +827,56 @@ class SplitWithNumOpPattern

}
};

class GreaterEqualOpPattern
    : public pir::OpRewritePattern<paddle::dialect::GreaterEqualOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::GreaterEqualOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::GreaterEqualOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Skip ops that are already marked as runnable in TensorRT.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if IS_TRT_VERSION_LT(8400)
    VLOG(3) << "GreaterEqualOp is not supported when TensorRT < 8.4";
    return false;
#else
    // Bool operands are rejected because the converter does not support them.
    pir::Value x = op.operand_source(0);
    pir::Value y = op.operand_source(1);
    auto x_dtype = pir::GetDataTypeFromValue(x);
    auto y_dtype = pir::GetDataTypeFromValue(y);
    if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
      VLOG(3) << "greater_equal op does not support bool datatype";
      return false;
    }
#endif
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
class MultiplyOpPattern
    : public pir::OpRewritePattern<paddle::dialect::MultiplyOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::MultiplyOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::MultiplyOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    // elementwise_mul does not support bool inputs; check both operands.
    pir::Value x = op.operand_source(0);
    pir::Value y = op.operand_source(1);
    auto x_dtype = pir::GetDataTypeFromValue(x);
    auto y_dtype = pir::GetDataTypeFromValue(y);
    if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
      VLOG(3) << "elementwise_mul does not support boolean datatype.";
      return false;
    }

    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};


class TrtOpMarkerPass : public pir::PatternRewritePass {
 public:
  TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -878,6 +929,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
    ps.Add(std::make_unique<CastOpPattern>(context));
    ps.Add(std::make_unique<SplitOpPattern>(context));
    ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
    ps.Add(std::make_unique<GreaterEqualOpPattern>(context));
    ps.Add(std::make_unique<MultiplyOpPattern>(context));
    return ps;
  }
};
@@ -891,4 +944,4 @@ std::unique_ptr<Pass> CreateTrtOpMarkerPass() {
}
} // namespace pir

REGISTER_IR_PASS(trt_op_marker_pass, TrtOpMarkerPass);
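
For context when reviewing, here is a minimal usage sketch of the pass this file registers; it is not part of the diff. It assumes the pir::PassManager / pir::IrContext API and the include paths seen in Paddle's PIR pass tests, so names and paths may need adjusting to the actual tree.

// Usage sketch (assumed API: pir::PassManager(ctx), AddPass, Run; assumed
// include paths). Running only the marker pass tags every op accepted by a
// pattern above with kCanRunTrtAttr = true for later TensorRT lowering.
#include "paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/pass/pass_manager.h"

void MarkTrtCompatibleOps(pir::Program *program) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  pir::PassManager pass_manager(ctx);
  pass_manager.AddPass(pir::CreateTrtOpMarkerPass());
  pass_manager.Run(program);
}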