Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] fix name in alpha_equal #2674

Merged
merged 1 commit into from
Feb 26, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 45 additions & 45 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,20 +217,20 @@ class AlphaEqualHandler:
return false;
}

bool VisitType_(const GlobalTypeVarNode* op, const Type& t2) final {
return GetRef<Type>(op) == t2;
bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
return GetRef<Type>(lhs) == other;
}

bool VisitType_(const TypeCallNode* op, const Type& t2) final {
const TypeCallNode* pt = t2.as<TypeCallNode>();
if (pt == nullptr
|| op->args.size() != pt->args.size()
|| !TypeEqual(op->func, pt->func)) {
bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
const TypeCallNode* rhs = other.as<TypeCallNode>();
if (rhs == nullptr
|| lhs->args.size() != rhs->args.size()
|| !TypeEqual(lhs->func, rhs->func)) {
return false;
}

for (size_t i = 0; i < op->args.size(); ++i) {
if (!TypeEqual(op->args[i], pt->args[i])) {
for (size_t i = 0; i < lhs->args.size(); ++i) {
if (!TypeEqual(lhs->args[i], rhs->args[i])) {
return false;
}
}
Expand Down Expand Up @@ -369,8 +369,8 @@ class AlphaEqualHandler:
}
}

bool VisitExpr_(const OpNode* op, const Expr& other) final {
return op == other.get();
bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
return lhs == other.get();
}

bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
Expand All @@ -389,80 +389,80 @@ class AlphaEqualHandler:
}
}

bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final {
if (const RefCreateNode* nr = e2.as<RefCreateNode>()) {
return ExprEqual(op->value, nr->value);
bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
return ExprEqual(lhs->value, rhs->value);
} else {
return false;
}
}

bool VisitExpr_(const RefReadNode* op, const Expr& e2) final {
if (const RefReadNode* r = e2.as<RefReadNode>()) {
return ExprEqual(op->ref, r->ref);
bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
if (const RefReadNode* rhs = other.as<RefReadNode>()) {
return ExprEqual(lhs->ref, rhs->ref);
} else {
return false;
}
}

bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final {
if (const RefWriteNode* r = e2.as<RefWriteNode>()) {
return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value);
bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, rhs->value);
} else {
return false;
}
}

bool VisitExpr_(const ConstructorNode* op, const Expr& e2) final {
return GetRef<Expr>(op) == e2;
bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
return GetRef<Expr>(lhs) == other;
}

bool ClauseEqual(const Clause& l, const Clause& r) {
return PatternEqual(l->lhs, r->lhs) && ExprEqual(l->rhs, r->rhs);
bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
}

bool PatternEqual(const Pattern& l, const Pattern& r) {
return VisitPattern(l, r);
bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
return VisitPattern(lhs, rhs);
}

bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) final {
return r.as<PatternWildcardNode>();
bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
return other.as<PatternWildcardNode>();
}

bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) final {
if (const auto* r = e2.as<PatternVarNode>()) {
return MergeVarDecl(op->var, r->var);
bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
if (const auto* rhs = other.as<PatternVarNode>()) {
return MergeVarDecl(lhs->var, rhs->var);
}
return false;
}

bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) final {
const auto* r = e2.as<PatternConstructorNode>();
if (r == nullptr
|| !ExprEqual(op->constructor, r->constructor)
|| op->patterns.size() != r->patterns.size()) {
bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) final {
const auto* rhs = other.as<PatternConstructorNode>();
if (rhs == nullptr
|| !ExprEqual(lhs->constructor, rhs->constructor)
|| lhs->patterns.size() != rhs->patterns.size()) {
return false;
}

for (size_t i = 0; i < op->patterns.size(); i++) {
if (!PatternEqual(op->patterns[i], r->patterns[i])) {
for (size_t i = 0; i < lhs->patterns.size(); i++) {
if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
return false;
}
}
return true;
}

bool VisitExpr_(const MatchNode* op, const Expr& e2) final {
const MatchNode* r = e2.as<MatchNode>();
bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
const MatchNode* rhs = other.as<MatchNode>();

if (r == nullptr
|| !ExprEqual(op->data, r->data)
|| op->clauses.size() != r->clauses.size()) {
if (rhs == nullptr
|| !ExprEqual(lhs->data, rhs->data)
|| lhs->clauses.size() != rhs->clauses.size()) {
return false;
}

for (size_t i = 0; i < op->clauses.size(); ++i) {
if (!ClauseEqual(op->clauses[i], r->clauses[i])) {
for (size_t i = 0; i < lhs->clauses.size(); ++i) {
if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
return false;
}
}
Expand Down