Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor][relay pass] Separate analysis and transform passes #5035

Merged
merged 2 commits into from
Mar 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ file(GLOB_RECURSE RELAY_OP_SRCS
src/relay/op/*.cc
)
file(GLOB_RECURSE RELAY_PASS_SRCS
src/relay/pass/*.cc
src/relay/analysis/*.cc
src/relay/transforms/*.cc
src/relay/quantize/*.cc
)
file(GLOB RELAY_BACKEND_SRCS
src/relay/backend/*.cc
Expand Down
9 changes: 4 additions & 5 deletions python/tvm/relay/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tvm.ir import RelayExpr, IRModule

from . import _analysis
from . import _make
from .ty import Type
from .feature import Feature

Expand Down Expand Up @@ -237,7 +236,7 @@ def alpha_equal(lhs, rhs):
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
return bool(_analysis._alpha_equal(lhs, rhs))


def assert_alpha_equal(lhs, rhs):
Expand All @@ -251,7 +250,7 @@ def assert_alpha_equal(lhs, rhs):
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_alpha_equal(lhs, rhs)
_analysis._assert_alpha_equal(lhs, rhs)


def graph_equal(lhs, rhs):
Expand All @@ -273,7 +272,7 @@ def graph_equal(lhs, rhs):
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
return bool(_analysis._graph_equal(lhs, rhs))


def assert_graph_equal(lhs, rhs):
Expand All @@ -290,7 +289,7 @@ def assert_graph_equal(lhs, rhs):
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_graph_equal(lhs, rhs)
_analysis._assert_graph_equal(lhs, rhs)


def collect_device_info(expr):
Expand Down
2 changes: 1 addition & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "meta_data.h"
#include "../relay/pass/dependency_graph.h"
#include "../relay/analysis/dependency_graph.h"
#include "../ir/attr_functor.h"

namespace tvm {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file src/tvm/relay/ir/alpha_equal.cc
* \file src/relay/analysis/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
#include <tvm/ir/type_functor.h>
Expand Down Expand Up @@ -593,8 +593,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
}

// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
TVM_REGISTER_GLOBAL("relay._analysis._alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b);
});
Expand All @@ -604,18 +603,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
return AlphaEqual(a, b);
});

TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal")
TVM_REGISTER_GLOBAL("relay._analysis._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});

TVM_REGISTER_GLOBAL("relay._make._graph_equal")
TVM_REGISTER_GLOBAL("relay._analysis._graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b);
});

TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal")
TVM_REGISTER_GLOBAL("relay._analysis._assert_graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file tvm/relay/pass/call_graph.cc
* \file src/relay/analysis/call_graph.cc
* \brief Implementation of APIs to handle the call graph of a Relay module.
*/

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
*/

/*!
* \file tvm/relay/pass/call_graph.h
* \file src/relay/analysis/call_graph.h
* \brief Define data structures for the call graph of a IRModule. It borrows
* the idea how LLVM constructs CallGraph.
*
* https://llvm.org/doxygen/CallGraph_8h_source.html
*/

#ifndef TVM_RELAY_PASS_CALL_GRAPH_H_
#define TVM_RELAY_PASS_CALL_GRAPH_H_
#ifndef TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
#define TVM_RELAY_ANALYSIS_CALL_GRAPH_H_

#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -510,4 +510,4 @@ class CallGraphEntry {

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_CALL_GRAPH_H_
#endif // TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
*/

/*!
* \file tvm/relay/pass/dependency_graph.cc
* \brief
* \file src/relay/analysis/dependency_graph.cc
* \brief Implementation of dependency graph APIs.
*/
#include "dependency_graph.h"
#include <tvm/relay/expr_functor.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
*/

/*!
* \file tvm/relay/pass/dependency_graph.h
* \file src/relay/analysis/dependency_graph.h
* \brief create a dependency graph.
*/
#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#ifndef TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_
#define TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_

#include <tvm/relay/expr.h>
#include <unordered_map>
#include <vector>
#include "let_list.h"
#include "../transforms/let_list.h"
#include "../../support/arena.h"

namespace tvm {
Expand Down Expand Up @@ -72,4 +72,4 @@ class DependencyGraph {

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_
#endif // TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/module.h>
#include "pass_util.h"
#include "../transforms/pass_util.h"

namespace tvm {
namespace relay {
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include "pattern_util.h"
#include "../transforms/pattern_util.h"

namespace tvm {
namespace relay {
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* \file type_solver.h
* \brief Solver logic for type inference.
*/
#ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_
#define TVM_RELAY_PASS_TYPE_SOLVER_H_
#ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
#define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
Expand All @@ -34,7 +34,6 @@
#include <unordered_set>
#include "../../support/arena.h"


namespace tvm {
namespace relay {

Expand Down Expand Up @@ -219,4 +218,4 @@ class TypeSolver {

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SOLVER_H_
#endif // TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
2 changes: 1 addition & 1 deletion src/relay/pass/util.cc → src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../transforms/pass_util.h"

namespace tvm {
namespace relay {
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#include <vector>
#include "../utils.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../transforms/pass_util.h"
#include "../../op/op_common.h"
#include "compiler.h"

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include "../../../runtime/vm/profiler/vm.h"
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../transforms/pass_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h>

#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../type_relations.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/device_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <tvm/relay/op_attr_types.h>

#include "type_relations.h"
#include "../pass/infer_layout_util.h"
#include "../transforms/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <tvm/relay/attrs/memory.h>

#include "../op_common.h"
#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../type_relations.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/bitserial.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/op.h>

#include "../op_common.h"
#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
#include "convolution.h"

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <vector>
#include <string>
#include "../type_relations.h"
#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
#include "nn.h"

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h>
#include <vector>
#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/infer_layout_util.h"
#include "../../transforms/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <string>
#include <unordered_map>
#include "type_relations.h"
#include "../pass/infer_layout_util.h"
#include "../transforms/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
#include <vector>
#include "../op_common.h"
#include "../../../arith/compute_expr.h"
#include "../../pass/infer_layout_util.h"
#include "../../pass/pattern_util.h"
#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
#include "transform.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"
#include "op_common.h"

Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/concatenate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../op/tensor/transform.h"
#include "../../pass/pattern_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
#include "../../op/nn/convolution.h"
#include "../../pass/pattern_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../op/nn/nn.h"
#include "../../pass/pattern_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/qnn/op/mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>
#include "../../pass/pattern_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"
#include "op_common.h"

Expand Down
Loading