diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index f79a1ab8fe3b..6022267406df 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -45,7 +45,7 @@ class IRTransformer final : public IRMutator { } private: - template + template T MutateInternal(T node) { if (only_enable_.size() && !only_enable_.count(node->type_index())) { @@ -89,11 +89,11 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) static FMutateStmt inst; return inst; } -inline Array MutateArray(Array arr, IRMutator *m) { - return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); }); +inline Array MutateArray(Array arr, IRMutator* m) { + return UpdateArray(arr, [&m](const Expr& e) { return m->Mutate(e); }); } -inline Array MutateIterVarArr(Array rdom, IRMutator *m) { +inline Array MutateIterVarArr(Array rdom, IRMutator* m) { std::vector new_dom(rdom.size()); bool changed = false; for (size_t i = 0; i < rdom.size(); i++) { @@ -133,7 +133,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const LetStmt* op, const Stmt& s) { Expr value = this->Mutate(op->value); Stmt body = this->Mutate(op->body); if (value.same_as(op->value) && @@ -144,7 +144,7 @@ Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const For *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const For* op, const Stmt& s) { Expr min = this->Mutate(op->min); Expr extent = this->Mutate(op->extent); Stmt body = this->Mutate(op->body); @@ -185,7 +185,7 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const IfThenElse* op, const Stmt& s) { Expr condition = this->Mutate(op->condition); Stmt then_case = this->Mutate(op->then_case); Stmt else_case; @@ -201,7 +201,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const Store* op, const Stmt& s) { Expr value = this->Mutate(op->value); Expr index = this->Mutate(op->index); Expr pred = this->Mutate(op->predicate); @@ -233,7 +233,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { Expr old_extent = op->bounds[i]->extent; Expr new_min = m->Mutate(old_min); Expr new_extent = m->Mutate(old_extent); - if (!new_min.same_as(old_min)) bounds_changed = true; + if (!new_min.same_as(old_min)) bounds_changed = true; if (!new_extent.same_as(old_extent)) bounds_changed = true; new_bounds.push_back( Range::make_by_min_extent(new_min, new_extent)); @@ -263,7 +263,7 @@ Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) { Expr old_extent = op->bounds[i]->extent; Expr new_min = m->Mutate(old_min); Expr new_extent = m->Mutate(old_extent); - if (!new_min.same_as(old_min)) bounds_changed = true; + if (!new_min.same_as(old_min)) bounds_changed = true; if (!new_extent.same_as(old_extent)) bounds_changed = true; new_bounds.push_back( Range::make_by_min_extent(new_min, new_extent)); @@ -288,7 +288,7 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const AssertStmt* op, const Stmt& s) { Expr condition = this->Mutate(op->condition); Expr message = this->Mutate(op->message); Stmt body = this->Mutate(op->body); @@ -302,7 +302,7 @@ Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const ProducerConsumer* op, const Stmt& s) { Stmt body = this->Mutate(op->body); if (body.same_as(op->body)) { return s; @@ -311,7 +311,7 @@ Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const Evaluate* op, const Stmt& s) { Expr v = this->Mutate(op->value); if (v.same_as(op->value)) { return s; @@ -320,7 +320,7 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const Free* op, const Stmt& s) { return s; } @@ -348,11 +348,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) return m->Mutate_(static_cast(node.get()), e); \ }) -Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { +Expr IRMutator::Mutate_(const Variable* op, const Expr& e) { return e; } -Expr IRMutator::Mutate_(const Load *op, const Expr& e) { +Expr IRMutator::Mutate_(const Load* op, const Expr& e) { Expr index = this->Mutate(op->index); Expr pred = this->Mutate(op->predicate); if (index.same_as(op->index) && pred.same_as(op->predicate)) { @@ -362,7 +362,7 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Let *op, const Expr& e) { +Expr IRMutator::Mutate_(const Let* op, const Expr& e) { Expr value = this->Mutate(op->value); Expr body = this->Mutate(op->body); if (value.same_as(op->value) && @@ -413,8 +413,8 @@ DEFINE_BIOP_EXPR_MUTATE_(GE) DEFINE_BIOP_EXPR_MUTATE_(And) DEFINE_BIOP_EXPR_MUTATE_(Or) -Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { - Array new_axis = MutateIterVarArr(op->axis, this); +Expr IRMutator::Mutate_(const Reduce* op, const Expr& e) { + Array new_axis = MutateIterVarArr(op->axis, this); Array new_source = MutateArray(op->source, this); Expr new_cond = this->Mutate(op->condition); if (op->axis.same_as(new_axis) && @@ -427,7 +427,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { +Expr IRMutator::Mutate_(const Cast* op, const Expr& e) { Expr value = this->Mutate(op->value); if (value.same_as(op->value)) { return e; @@ -436,7 +436,7 @@ Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Not *op, const Expr& e) { +Expr IRMutator::Mutate_(const Not* op, const Expr& e) { Expr a = this->Mutate(op->a); if (a.same_as(op->a)) { return e; @@ -445,7 +445,7 @@ Expr IRMutator::Mutate_(const Not *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Select *op, const Expr& e) { +Expr IRMutator::Mutate_(const Select* op, const Expr& e) { Expr cond = this->Mutate(op->condition); Expr t = this->Mutate(op->true_value); Expr f = this->Mutate(op->false_value); @@ -458,7 +458,7 @@ Expr IRMutator::Mutate_(const Select *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) { +Expr IRMutator::Mutate_(const Ramp* op, const Expr& e) { Expr base = this->Mutate(op->base); Expr stride = this->Mutate(op->stride); if (base.same_as(op->base) && @@ -469,7 +469,7 @@ Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { +Expr IRMutator::Mutate_(const Broadcast* op, const Expr& e) { Expr value = this->Mutate(op->value); if (value.same_as(op->value)) { return e; @@ -478,7 +478,7 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { } } -Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) { +Expr IRMutator::Mutate_(const Shuffle* op, const Expr& e) { auto new_vec = MutateArray(op->vectors, this); if (new_vec.same_as(op->vectors)) { return e; diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 204c0f75fe4a..d6f163ccedc6 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -43,7 +43,6 @@ class IRApplyVisit : public IRVisitor { std::unordered_set visited_; }; - void PostOrderVisit(const NodeRef& node, std::function fvisit) { IRApplyVisit(fvisit).Visit(node); } @@ -68,7 +67,7 @@ inline void VisitRDom(const Array& rdom, IRVisitor* v) { void IRVisitor::Visit_(const Variable* op) {} -void IRVisitor::Visit_(const LetStmt *op) { +void IRVisitor::Visit_(const LetStmt* op) { this->Visit(op->value); this->Visit(op->body); } @@ -78,14 +77,14 @@ void IRVisitor::Visit_(const AttrStmt* op) { this->Visit(op->body); } -void IRVisitor::Visit_(const For *op) { +void IRVisitor::Visit_(const For* op) { IRVisitor* v = this; v->Visit(op->min); v->Visit(op->extent); v->Visit(op->body); } -void IRVisitor::Visit_(const Allocate *op) { +void IRVisitor::Visit_(const Allocate* op) { IRVisitor* v = this; for (size_t i = 0; i < op->extents.size(); i++) { v->Visit(op->extents[i]); @@ -97,18 +96,18 @@ void IRVisitor::Visit_(const Allocate *op) { } } -void IRVisitor::Visit_(const Load *op) { +void IRVisitor::Visit_(const Load* op) { this->Visit(op->index); this->Visit(op->predicate); } -void IRVisitor::Visit_(const Store *op) { +void IRVisitor::Visit_(const Store* op) { this->Visit(op->value); this->Visit(op->index); this->Visit(op->predicate); } -void IRVisitor::Visit_(const IfThenElse *op) { +void IRVisitor::Visit_(const IfThenElse* op) { this->Visit(op->condition); this->Visit(op->then_case); if (op->else_case.defined()) { @@ -116,14 +115,14 @@ void IRVisitor::Visit_(const IfThenElse *op) { } } -void IRVisitor::Visit_(const Let *op) { +void IRVisitor::Visit_(const Let* op) { this->Visit(op->value); this->Visit(op->body); } void IRVisitor::Visit_(const Free* op) {} -void IRVisitor::Visit_(const Call *op) { +void IRVisitor::Visit_(const Call* op) { VisitArray(op->args, this); } @@ -171,38 +170,38 @@ void IRVisitor::Visit_(const Select* op) { this->Visit(op->false_value); } -void IRVisitor::Visit_(const Ramp *op) { +void IRVisitor::Visit_(const Ramp* op) { this->Visit(op->base); this->Visit(op->stride); } -void IRVisitor::Visit_(const Shuffle *op) { - for (const auto &elem : op->indices) +void IRVisitor::Visit_(const Shuffle* op) { + for (const auto& elem : op->indices) this->Visit(elem); - for (const auto &elem : op->vectors) + for (const auto& elem : op->vectors) this->Visit(elem); } -void IRVisitor::Visit_(const Broadcast *op) { +void IRVisitor::Visit_(const Broadcast* op) { this->Visit(op->value); } -void IRVisitor::Visit_(const AssertStmt *op) { +void IRVisitor::Visit_(const AssertStmt* op) { this->Visit(op->condition); this->Visit(op->message); this->Visit(op->body); } -void IRVisitor::Visit_(const ProducerConsumer *op) { +void IRVisitor::Visit_(const ProducerConsumer* op) { this->Visit(op->body); } -void IRVisitor::Visit_(const Provide *op) { +void IRVisitor::Visit_(const Provide* op) { VisitArray(op->args, this); this->Visit(op->value); } -void IRVisitor::Visit_(const Realize *op) { +void IRVisitor::Visit_(const Realize* op) { for (size_t i = 0; i < op->bounds.size(); i++) { this->Visit(op->bounds[i]->min); this->Visit(op->bounds[i]->extent); @@ -212,19 +211,19 @@ void IRVisitor::Visit_(const Realize *op) { this->Visit(op->condition); } -void IRVisitor::Visit_(const Prefetch *op) { +void IRVisitor::Visit_(const Prefetch* op) { for (size_t i = 0; i < op->bounds.size(); i++) { this->Visit(op->bounds[i]->min); this->Visit(op->bounds[i]->extent); } } -void IRVisitor::Visit_(const Block *op) { +void IRVisitor::Visit_(const Block* op) { this->Visit(op->first); this->Visit(op->rest); } -void IRVisitor::Visit_(const Evaluate *op) { +void IRVisitor::Visit_(const Evaluate* op) { this->Visit(op->value); }