From 6fc0ca3ccd2b3e66c54bab38cc8f0f5e91f6fefe Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 18 Oct 2018 15:21:13 -0700 Subject: [PATCH 1/2] [RELAY] IR builder refactor, clean pass --- include/tvm/relay/attrs/nn.h | 2 +- include/tvm/relay/environment.h | 28 +- include/tvm/relay/expr.h | 5 +- include/tvm/relay/pass.h | 23 +- python/tvm/relay/__init__.py | 4 +- python/tvm/relay/env.py | 106 ++--- python/tvm/relay/expr.py | 63 ++- python/tvm/relay/ir_builder.py | 387 ------------------ python/tvm/relay/ir_pass.py | 69 ++-- python/tvm/relay/scope_builder.py | 185 +++++++++ python/tvm/relay/ty.py | 54 ++- src/relay/ir/environment.cc | 93 ++--- src/relay/ir/text_printer.cc | 12 +- src/relay/op/image/resize.cc | 1 + src/relay/op/nn/nn.cc | 5 + src/relay/op/nn/pad.cc | 1 + src/relay/op/nn/pooling.cc | 5 + src/relay/op/nn/upsampling.cc | 1 + src/relay/op/tensor/reduce.cc | 4 +- src/relay/op/tensor/transform.cc | 13 + src/relay/op/type_relations.cc | 177 ++------ src/relay/op/type_relations.h | 27 -- src/relay/op/vision/multibox_op.cc | 1 + src/relay/pass/dead_code.cc | 4 +- src/relay/pass/type_infer.cc | 72 +++- tests/python/relay/test_ir_builder.py | 19 - tests/python/relay/test_ir_nodes.py | 2 +- tests/python/relay/test_ir_text_printer.py | 32 +- tests/python/relay/test_op_level1.py | 339 +++++---------- tests/python/relay/test_op_level2.py | 315 ++++++-------- tests/python/relay/test_op_level3.py | 239 ++++------- tests/python/relay/test_op_level4.py | 223 +++------- tests/python/relay/test_op_level5.py | 54 +-- tests/python/relay/test_pass_alpha_equal.py | 77 ++-- .../relay/test_pass_dead_code_elimination.py | 22 +- tests/python/relay/test_type_infer.py | 131 +++--- tests/python/relay/test_type_solver.py | 2 - 37 files changed, 1089 insertions(+), 1708 deletions(-) delete mode 100644 python/tvm/relay/ir_builder.py create mode 100644 python/tvm/relay/scope_builder.py delete mode 100644 tests/python/relay/test_ir_builder.py diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 5dbaecdc3e78..6b522ef3bfd0 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode { struct LeakyReluAttrs : public tvm::AttrsNode { double alpha; - TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") { + TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") { TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25) .describe("Slope coefficient for the negative half axis."); } diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/environment.h index 46cedf12b816..2ed389571ad6 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/environment.h @@ -47,12 +47,13 @@ class EnvironmentNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); - v->Visit("global_map_", &global_map_); + v->Visit("global_var_map_", &global_var_map_); } TVM_DLL static Environment make(tvm::Map global_funcs); - /*! \brief Add a function to the global environment. + /*! + * \brief Add a function to the global environment. * \param var The name of the global function. * \param func The function. * \param update Controls whether you can replace a definition in the @@ -60,39 +61,46 @@ class EnvironmentNode : public RelayNode { */ void Add(const GlobalVar& var, const Function& func, bool update = false); - /*! \brief Update a function in the global environment. + /*! + * \brief Update a function in the global environment. * \param var The name of the global function to update. * \param func The new function. */ void Update(const GlobalVar& var, const Function& func); - /*! \brief Remove a function from the global environment. + /*! + * \brief Remove a function from the global environment. * \param var The name of the global function to update. */ void Remove(const GlobalVar& var); - /*! \brief Lookup a global function by its variable. + /*! + * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ GlobalVar GetGlobalVar(const std::string& str); - /*! \brief Lookup a global function by its variable. + /*! + * \brief Lookup a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ Function Lookup(const GlobalVar& var); - /*! \brief Lookup a global function by its string name + /*! + * \brief Lookup a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ Function Lookup(const std::string& name); - /*! \brief Combine with another Environment. + /*! + * \brief Update the functions inside this environment by + * functions in another environment. * \param other The other environment. */ - void Merge(const Environment& other); + void Update(const Environment& other); static constexpr const char* _type_key = "relay.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); @@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode { /*! \brief A map from string names to global variables that * ensures global uniqueness. */ - tvm::Map global_map_; + tvm::Map global_var_map_; }; struct Environment : public NodeRef { diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 743dc085d035..d0b58e0213c7 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode { int index; void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("tuple", &tuple); + v->Visit("tuple_value", &tuple); v->Visit("index", &index); + v->Visit("_checked_type_", &checked_type_); } TVM_DLL static TupleGetItem make(Expr tuple, int index); - static constexpr const char * _type_key = "relay.GetItem"; + static constexpr const char * _type_key = "relay.TupleGetItem"; TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); }; diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1043e4aaaa4c..04f6a1842ee6 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -12,21 +12,30 @@ namespace tvm { namespace relay { -/*! \brief Infer the type of an expression with the provided environment. +/*! + * \brief Infer the type of an expression. * * The result of type checking is a new expression with unambigous * type information filled in, as well as it's checked type field * populated with the result type. * - * \param env The environment used for global settings and referencing - * global functions. - * - * \param e The expression to type check. + * \param expr The expression to type check. + * \param env The environment used for referencing global functions, can be None. * * \return A type checked expression with its checked_type field populated. */ -Expr InferType(const Environment& env, const Expr& e); -Expr InferType(const Environment& env, const GlobalVar& var, const Function& f); +Expr InferType(const Expr& expr, const Environment& env); +/*! + * \brief Infer the type of a function as if it is mapped to var in the env. + * + * \param f the function. + * \param env The environment used for referencing global functions. + * \param var The global variable corresponding to the function. + * + * \return A type checked Function with its checked_type field populated. + * \note this function mutates env and is not thread-safe. + */ +Function InferType(const Function& f, const Environment& env, const GlobalVar& var); /*! * \brief Check that types are well kinded by applying "kinding rules". diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index d6ecdb7855d8..4e53b6ba9aab 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -5,7 +5,6 @@ from . import expr from . import env from . import ir_pass -from . import ir_builder # Root operators from .op import Op @@ -16,6 +15,8 @@ from . import vision from . import image +from .scope_builder import ScopeBuilder + # Span Span = base.Span @@ -32,6 +33,7 @@ FuncType = ty.FuncType TypeRelation = ty.TypeRelation IncompleteType = ty.IncompleteType +scalar_type = ty.scalar_type # Expr Constant = expr.Constant diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 8c226e509a12..9c3241e18ef8 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -1,31 +1,40 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import """A global environment storing everything needed to interpret or compile a Relay program.""" from .base import register_relay_node, RelayNode +from .._ffi import base as _base from . import _make from . import _env +from . import expr as _expr + @register_relay_node class Environment(RelayNode): - """The global Relay environment containing functions, - options and more. - """ - - def __init__(self, funcs=None): - """Construct an environment. - - Parameters - ------ - funcs : optional, dict - Map of global var to Function + """The global Relay environment containing collection of functions. - Returns - ------ - env: A new environment containing :py:class:`~relay.env.Environment`. - """ - funcs = funcs if funcs else {} - self.__init_handle_by_constructor__(_make.Environment, funcs) + Each global function is identified by an unique tvm.relay.GlobalVar. + tvm.relay.GlobalVar and Environment is necessary in order to enable + recursions in function to avoid cyclic reference in the function.x - def add(self, var, func): + Parameters + ---------- + functions : dict, optional. + Map of global var to Function + """ + def __init__(self, functions=None): + if functions is None: + functions = {} + elif isinstance(functions, dict): + mapped_funcs = {} + for k, v in functions.items(): + if isinstance(k, _base.string_types): + k = _expr.GlobalVar(k) + if not isinstance(k, _expr.GlobalVar): + raise TypeError("Expect functions to be Dict[GlobalVar, Function]") + mapped_funcs[k] = v + functions = mapped_funcs + self.__init_handle_by_constructor__(_make.Environment, functions) + + def __setitem__(self, var, func): """Add a function to the environment. Parameters @@ -36,50 +45,55 @@ def add(self, var, func): func: Function The function. """ - if isinstance(var, str): - var = _env.Environment_GetGlobalVar(self, var) - + if isinstance(var, _base.string_types): + var = _expr.GlobalVar(var) _env.Environment_Add(self, var, func) - def merge(self, other): - """Merge two environments. + def __getitem__(self, var): + """Lookup a global function by name or by variable. Parameters ---------- - other: Environment - The environment to merge into the current Environment. + var: str or GlobalVar + The name or global variable. + + Returns + ------- + func: Function + The function referenced by :code:`var`. """ - return _env.Environment_Merge(self, other) + if isinstance(var, _base.string_types): + return _env.Environment_Lookup_str(self, var) + else: + return _env.Environment_Lookup(self, var) - def global_var(self, name): - """Get a global variable by name. + def update(self, other): + """Insert functions in another Environment to current one. Parameters ---------- - name: str - The name of the global variable. - - Returns - ------- - global_var: GlobalVar - The global variable mapped to :code:`name`. + other: Environment + The environment to merge into the current Environment. """ - return _env.Environment_GetGlobalVar(self, name) + if isinstance(other, dict): + other = Environment(other) + return _env.Environment_Update(self, other) - def __getitem__(self, var): - """Lookup a global function by name or by variable. + def get_global_var(self, name): + """Get a global variable in the function by name. Parameters ---------- - var: str or GlobalVar - The name or global variable. + name: str + The name of the global variable. Returns ------- - func: Function - The function referenced by :code:`var`. + global_var: GlobalVar + The global variable mapped to :code:`name`. + + Raises + ------ + tvm.TVMError if we cannot find corresponding global var. """ - if isinstance(var, str): - return _env.Environment_Lookup_str(self, var) - else: - return _env.Environment_Lookup(self, var) + return _env.Environment_GetGlobalVar(self, name) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 9807fab45089..36116d07d601 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -28,9 +28,6 @@ def checked_type(self): " the checked_type for this node") return ret - def __call__(self, *args): - return Call(self, args, None, None) - @register_relay_node class Constant(Expr): @@ -57,6 +54,14 @@ class Tuple(Expr): def __init__(self, fields): self.__init_handle_by_constructor__(_make.Tuple, fields) + def __getitem__(self, index): + if index >= len(self): + raise IndexError("Tuple index out of range") + return self.fields[index] + + def __len__(self): + return len(self.fields) + @register_relay_node class Var(Expr): @@ -95,6 +100,16 @@ class GlobalVar(Expr): def __init__(self, name_hint): self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) + def __call__(self, *args): + """Invoke the gobal function. + + Parameters + ---------- + args: List[relay.Expr] + Arguments. + """ + return Call(self, args, None, None) + @register_relay_node class Function(Expr): @@ -126,6 +141,16 @@ def __init__(self, self.__init_handle_by_constructor__( _make.Function, params, body, ret_type, type_params) + def __call__(self, *args): + """Invoke the gobal function. + + Parameters + ---------- + args: List[relay.Expr] + Arguments. + """ + return Call(self, args, None, None) + @register_relay_node class Call(Expr): @@ -238,11 +263,17 @@ def asnode(self): return self.tuple_value - def __getitem__(self, key): - return self.tuple_value.fields[key] + def __getitem__(self, index): + if index >= len(self): + raise IndexError("Tuple index out of range") + return TupleGetItem(self.tuple_value, index) def __len__(self): - return len(self.tuple_value.fields) + return self.size + + def __repr__(self): + return ("TupleWrapper(" + self.tuple_value.__repr__() + + ", " + self.size + ")") def var(name_hint, @@ -304,13 +335,27 @@ def const(value, dtype=None): dtype: str, optional The data type of the value. + + Note + ---- + When dtype is None, we use the following rule: + + - int maps to "int32" + - float maps to "float32" + - bool maps to "bool" + - other using the same default rule as numpy. """ - if isinstance(value, _base.numeric_types): - value = _np.array(value, dtype=dtype) - elif isinstance(value, (bool, list)): + if isinstance(value, (_base.numeric_types, (bool, list))): value = _np.array(value, dtype=dtype) + # convert default to int32 and float32 + if dtype is None: + if value.dtype == "float64": + value = value.astype("float32") + elif value.dtype == "int64": + value = value.astype("int32") if isinstance(value, (_np.ndarray, _np.generic)): value = _nd.array(value) + if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") return Constant(value) diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py deleted file mode 100644 index d2771926e58f..000000000000 --- a/python/tvm/relay/ir_builder.py +++ /dev/null @@ -1,387 +0,0 @@ -# pylint: disable=no-else-return -"""IR builder for the Relay IR. - -Enables users to construct Relay programs with a Python API. -""" -from collections import OrderedDict -import numpy as np -import tvm -from .ty import Type, FuncType, TensorType -from .expr import Expr, Constant, Let, Var, Function, If -from .env import Environment - - -def _convert_to_value(arg, ctxt=tvm.cpu(0)): - # type: (Any, tvm.Context) -> tvm.nd.NDArray - """Convert Python values into the appropriate types - for the Relay evaluator. - """ - if isinstance(arg, bool): # bool is subclass of int - return tvm.nd.array(np.array(arg, dtype='uint8'), ctxt) - elif isinstance(arg, int): - return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) - elif isinstance(arg, float): - return tvm.nd.array(arg, ctxt) - elif isinstance(arg, np.ndarray): - return tvm.nd.array(arg, ctxt) - elif isinstance(arg, tvm.ndarray.NDArray): - return arg - else: - # raise Exception(f"can't convert {type(arg)} to a Relay AST") - raise Exception("unsupported argument type {0}".format(type(arg))) - - -def _convert_type(rtype): - if isinstance(rtype, str): - return scalar_type(rtype) - elif isinstance(rtype, Type): - return rtype - else: - raise Exception( - "unsupported conversion to Relay type {0}".format(type(rtype))) - - -def convert(arg): - # type: (Any) -> Expr - """Convert some Python objects into a Relay AST fragment. - - Parameters - ---------- - arg: Any - The Python object - - Returns - ------- - expr: relay.Expr - The converted expression. - """ - if isinstance(arg, Expr): - return arg - elif isinstance(arg, tuple): - return relay.Tuple([convert(el) for el in arg]) - elif isinstance(arg, PartialFunc): - return arg.to_func() - elif isinstance(arg, tvm._ffi.node.NodeGeneric): - return arg.asnode() - else: - value = _convert_to_value(arg) - return Constant(value) - - -class WithScope(object): - """A wrapper for builder methods which introduce scoping.""" - - def __init__(self, enter_value, exit_cb): - self._enter_value = enter_value - self._exit_cb = exit_cb - - def __enter__(self): - return self._enter_value - - def __exit__(self, ptype, value, trace): - if value: - raise value - else: - self._exit_cb() - - -class PartialFunc(object): - """A wrapper around functions while they are being built. - - Used by the builder as a user is building up a function, - allows Function nodes which contain partially initialized - state. - """ - - def __init__(self, params, ret_type, body, type_params): - self.params = params - self.ret_type = ret_type - self.body = body - self.type_params = type_params - - def param_ids(self): - return [p for p in self.params] - - def to_func(self): - """Converts a PartialFunc into a :py:class:`~relay.Function`.""" - return Function( - self.params, - self.body, - self.ret_type, - self.type_params) - -#pylint: disable=invalid-name - - -def _mk_let(bindings, ret_value): - let_expr = ret_value - for var, value in reversed(list(bindings.items())): - let_expr = Let(var, value, let_expr) - return let_expr - - -class IRBuilder(object): - """The IRBuilder class. - - Enables users to build up a Relay environment and program. - - Examples - -------- - - Program: - fn (x : Tensor[f32, (10, 10)]) { - let t1 = log(x); - let t2 = add(t1, x); - return t1; - } - - ..code-block: python - b = IRBuilder() - with b.function(('x', tensor_type(10, 10))) as func: - x, = func.param_ids() - t1 = b.let('t1', log(x)) - t2 = b.let('t2', add(t1, x)) - b.ret(t2) - """ - - def __init__(self): - self.bindings = [OrderedDict({})] - self.scopes = [OrderedDict({})] - self.params = [] - self.ret_values = [None] - self.env = Environment({}) - - def enter_scope(self, params=None): - if not params: - params = [] - - self.bindings.append(OrderedDict({})) - self.scopes.append(OrderedDict({})) - self.params.append(params) - self.ret_values.append(None) - - def exit_scope(self): - bindings = self.bindings.pop() - scopes = self.scopes.pop() - params = self.params.pop() - ret_value = self.ret_values.pop() - return bindings, scopes, params, ret_value - - #pylint: disable=invalid-name - def bind(self, name, value, ty): - lv = Var(name, ty) - self.scopes[-1][name] = lv - self.bindings[-1][lv] = value - return lv - - def let(self, name, value, value_type=None): - if not isinstance(value, Expr): - value = convert(value) - - return self.bind(name, value, value_type) - - def _convert_params(self, raw_params): - relay_params = [] - for raw_param in raw_params: - if isinstance(raw_param, Var): - param = raw_param - elif isinstance(raw_param, tuple): - var, ty = raw_param - ty = _convert_type(ty) - param = Var(var, ty) - elif isinstance(raw_param, str): - param = Var(raw_param, None) - else: - raise Exception("unknown parameter type") - - self.scopes[-1][param.name_hint] = param - relay_params.append(param) - - return relay_params - - def function(self, *params): - """Construct a Relay function.""" - - relay_params = self._convert_params(params) - - self.enter_scope() - - pfunc = PartialFunc(relay_params, None, None, []) - - def _on_exit(): - bindings, _, _, ret_value = self.exit_scope() - body = _mk_let(bindings, ret_value) - pfunc.body = body - - return WithScope(pfunc, _on_exit) - - def ret(self, x): - """Set `x` to be the return value of the current function.""" - if not self.ret_values[-1]: - self.ret_values[-1] = convert(x) - else: - raise Exception( - "return value already set, a function can only have one return value") - - def if_scope(self, cond): - """Construct the if branch an if expression with scoping.""" - self.enter_scope() - - def _on_exit(): - bindings, _, _, ret_value = self.exit_scope() - assert self.ret_values[-1] is None - true_branch = _mk_let(bindings, ret_value) - self.ret_values[-1] = If(cond, true_branch, None) - - return WithScope(10, _on_exit) - - def else_scope(self): - """Construct the else branch of an if expression with scoping.""" - self.enter_scope() - - def _on_exit(): - bindings, _, _, ret_value = self.exit_scope() - partial_if = self.ret_values[-1] - assert isinstance( - partial_if, If) and partial_if.false_branch is None - false_branch = _mk_let(bindings, ret_value) - self.ret_values[-1] = If( - partial_if.cond, - partial_if.true_branch, - false_branch) - - return WithScope(10, _on_exit) - - def param(self, name, ty=None): - if not ty: - ty = scalar_type('float32') - else: - ty = _convert_type(ty) - - return Var(name, ty) - - def global_var(self, name): - # type: (str) -> GlobalVar - """Construct a global var with `name` as its name hint. - - Parameters - ---------- - name: str - The name of the global variable. - - Returns - ------- - global_var: relay.GlobalVar - The global variable with `name`. - - """ - return self.env.global_var(name) - - def decl(self, name, *params, **kwargs): - """Create a global function. - - Parameters - ---------- - name: str or GlobalVar - The name of the function. - params: params - The parameters of the function. - - Returns - ------- - with_scope: Scope for the function. - """ - - ret_type = kwargs.get('ret_type', None) - - self.enter_scope() - - def _on_exit(): - bindings, _, _, ret_value = self.exit_scope() - exp = _mk_let(bindings, ret_value) - self.env.add(name, Function(params, exp, ret_type)) - - return WithScope(10, _on_exit) - - def get(self): - """Get the full program. - - Returns - ---------- - (prog, env) : (relay.Expr, relay.Environment) - A pair of the partial program, and the modified environment. - """ - bindings = self.bindings.pop() - scope = self.scopes.pop() - - if self.bindings: - raise Exception("IRBuilder: binding error") - - if self.scopes: - raise Exception("IRBuilder: scoping error") - - if bindings and scope and not self.ret_values: - raise Exception("IRBuilder: no return value set") - - return _mk_let(bindings, self.ret_values[-1]), self.env - - -def scalar_type(dtype): - """Construct a Relay scalar type. - - Parameters - ---------- - dtype: dtype - The dtype of the scalar type. - - Returns: - scalar_type: relay.Type - The scalar type. - """ - return TensorType(tvm.convert([]), dtype) - - -def tensor_type(*shape, **kwargs): - """Construct a Relay Tensor type. - - Parameters - ---------- - shape: list of tvm.Expr - The shape of the Tensor type. - dtype: dtype - The dtype of the Tensor type. - - Returns - ------- - tensor_type: relay.Type - The resulting tensor types. - """ - dtype = kwargs.get('dtype', 'float32') - - return TensorType(tvm.convert(shape), dtype) - - -def func_type(args, ret_type, type_params=None): - """Construct a Relay function type. - - Parameters - ---------- - args: list of relay.Type - The argument types. - - ret_type: relay.Type - The return type. - - type_params: list of relay.TypeParam - The type parameters. - - Returns - ------- - func_type: The function type. - """ - if not type_params: - type_params = [] - - args = [_convert_type(arg) for arg in args] - ret_type = _convert_type(ret_type) - return FuncType(args, ret_type, type_params, []) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index cbb7095e2f17..549203d12c9f 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -2,37 +2,39 @@ # pylint: disable=unidiomatic-typecheck """The set of passes for Relay. -Exposes an interface for configuring the passes and scripting -them in Python. +Exposes an interface for configuring the passes and +scripting them in Python. """ from . import _ir_pass from . import _make # pylint: disable=invalid-name -def infer_type(env, expr): +def infer_type(expr, env=None): """Infer the type of expr under the context of env. Parameters ---------- - env : relay.Environment + expr: tvm.relay.Expr + The input expression. + + env: Optional[tvm.relay.Environment] The global environment. - expr : relay.Expr - The input expression. Returns ------- - checked_expr : relay.Expr + checked_expr : tvm.relay.Expr The checked expression. """ - return _ir_pass.infer_type(env, expr) + return _ir_pass.infer_type(expr, env) -def well_formed(e): + +def well_formed(expr): """Check that each Var is only bound once (well formed). Parameters ---------- - e: relay.Expr + expr: tvm.relay.Expr The input expression Returns @@ -40,7 +42,8 @@ def well_formed(e): well_form : bool whether the input expression is well formed """ - return _ir_pass.well_formed(e) + return _ir_pass.well_formed(expr) + def check_kind(t, env=None): """Check that the type is well kinded. @@ -48,10 +51,10 @@ def check_kind(t, env=None): Parameters ---------- - t: relay.Type + t: tvm.relay.Type The type to check - env: relay.Environment, optional + env: tvm.relay.Environment, optional The global environment Returns @@ -71,61 +74,65 @@ def check_kind(t, env=None): else: return _ir_pass.check_kind(t) + def free_vars(e): """Get free variables from expression e. Parameters ---------- - e: relay.Expr + e: tvm.relay.Expr The input expression Returns ------- - free : List[relay.Var] - the list of free variables + free : List[tvm.relay.Var] + The list of free variables """ return _ir_pass.free_vars(e) -def free_type_vars(e): + +def free_type_vars(expr): """Get free type variables from expression/type e Parameters ---------- - e: relay.Expr/relay.Type - The input expression/type + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type Returns ------- - free : List[relay.TypeParam] - the list of free type variables + free : List[tvm.relay.TypeParam] + The list of free type variables """ - return _ir_pass.free_type_vars(e) + return _ir_pass.free_type_vars(expr) -def dead_code_elimination(e): + +def dead_code_elimination(expr): """ Remove expressions which does not effect the program result (dead code). Parameters ---------- - e: relay.Expr - The input Expression + e: tvm.relay.Expr + The input Expression Returns ------- - result: relay.Expr - An expression which is semantically equal to the input expression, - but with dead code removed. + result: tvm.relay.Expr + An expression which is semantically equal to the input expression, + but with dead code removed. """ - return _ir_pass.dead_code_elimination(e) + return _ir_pass.dead_code_elimination(expr) + def alpha_equal(lhs, rhs): """Compare two Relay expr for structural equivalence (alpha equivalence). Parameters ---------- - lhs: relay.Expr + lhs: tvm.relay.Expr One of the input Expression. - rhs: relay.Expr + rhs: tvm.relay.Expr One of the input Expression. Returns diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py new file mode 100644 index 000000000000..641566946f58 --- /dev/null +++ b/python/tvm/relay/scope_builder.py @@ -0,0 +1,185 @@ +"""The scope builder interface """ +from __future__ import absolute_import + +from . import expr as _expr +from .._ffi import base as _base + +class WithScope(object): + """A wrapper for builder methods which introduce scoping. + + Parameters + ---------- + enter_value: object + The value returned by enter. + """ + + def __init__(self, enter_value, exit_cb): + self._enter_value = enter_value + self._exit_cb = exit_cb + + def __enter__(self): + return self._enter_value + + def __exit__(self, ptype, value, trace): + if value: + raise value + else: + self._exit_cb() + + +def _make_lets(bindings, ret_value): + """Make a nested let expressions. + + Parameters + ---------- + bindings: List[Tuple[tvm.relay.Var,tvm.relay.Expr]] + The sequence of let bindings + + ret_value: tvm.relay.Expr + The final value of the expression. + + Returns + ------- + lets: tvm.relay.Expr + A nested let expression. + """ + if ret_value is None: + raise RuntimeError("ret is not called in this scope") + if isinstance(ret_value, _expr.If) and ret_value.false_branch is None: + raise RuntimeError("Creating an If expression without else.") + let_expr = ret_value + for var, value in reversed(bindings): + let_expr = _expr.Let(var, value, let_expr) + return let_expr + + +class ScopeBuilder(object): + """Scope builder class. + + Enables users to build up a nested + scope(let, if) expression easily. + + Examples + -------- + ..code-block: python + + sb = relay.ScopeBuilder() + cond = relay.var("cond", 'bool') + x = relay.var("x") + y = relay.var("y") + + with sb.if_scope(cond): + one = relay.const(1, "float32") + t1 = sb.let(t1, relay.add(x, one)) + sb.ret(t1) + with sb.else_scope(): + sb.ret(y) + + print(sb.get().astext()) + """ + def __init__(self): + self._bindings = [[]] + self._ret_values = [None] + + def _enter_scope(self): + self._bindings.append([]) + self._ret_values.append(None) + + def _exit_scope(self): + bindings = self._bindings.pop() + ret_value = self._ret_values.pop() + return bindings, ret_value + + def let(self, var, value): + """Create a new let binding. + + Parameters + ---------- + var: Union[Tuple[str, relay.Type], tvm.relay.Var] + The variable or name of variable. + + value: tvm.relay.Expr + The value to be binded + """ + if isinstance(var, (tuple, list)): + if len(var) > 2: + raise ValueError("Expect var to be Tuple[str, relay.Type]") + var = _expr.var(*var) + elif isinstance(var, _base.string_types): + var = _expr.var(var) + self._bindings[-1].append((var, value)) + return var + + def if_scope(self, cond): + """Create a new if scope. + + Parameters + ---------- + cond: tvm.relay.Expr + The condition + + Returns + ------- + scope: WithScope + The if scope. + + Note + ---- + The user must follows with an else scope. + """ + self._enter_scope() + def _on_exit(): + bindings, ret_value = self._exit_scope() + if self._ret_values[-1] is not None: + raise RuntimeError("result already returned before if scope") + true_branch = _make_lets(bindings, ret_value) + self._ret_values[-1] = _expr.If(cond, true_branch, None) + return WithScope(None, _on_exit) + + def else_scope(self): + """Create a new else scope. + + Returns + ------- + scope: WithScope + The if scope. + """ + self._enter_scope() + + def _on_exit(): + bindings, ret_value = self._exit_scope() + partial_if = self._ret_values[-1] + no_else = (not isinstance(partial_if, _expr.If) or + partial_if.false_branch is not None) + if no_else: + raise RuntimeError("else scope must follows") + false_branch = _make_lets(bindings, ret_value) + self._ret_values[-1] = _expr.If( + partial_if.cond, + partial_if.true_branch, + false_branch) + return WithScope(None, _on_exit) + + def ret(self, value): + """Set the return value of this scope. + + Parameters + ---------- + value: tvm.relay.Expr + The return value. + """ + if self._ret_values[-1] is not None: + raise RuntimeError("ret value is already set in this scope.") + self._ret_values[-1] = value + + def get(self): + """Get the generated result. + + Returns + ------- + value: tvm.relay.Expr + The final result of the expression. + """ + if len(self._bindings) != 1: + raise RuntimeError("can only call get at the outmost scope") + return _make_lets(self._bindings[-1], self._ret_values[-1]) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 34bd60ea08bb..f3c61eec9155 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -122,26 +122,30 @@ class FuncType(Type): We informally write them as: `forall (type_params), (arg_types) -> ret_type where type_constraints` + + Parameters + ---------- + arg_types: List[tvm.relay.Type] + The argument types + + ret_type: tvm.relay.Type + The return type. + + type_params: List[tvm.relay.TypeParam] + The type parameters + + type_constraints: List[tvm.relay.TypeConstraint] + The type constraints. """ def __init__(self, arg_types, ret_type, - type_params, - type_constraints): - """Construct a function type. - - Parameters - ---------- - arg_types: list of Type - ret_type: Type - type_params: list of TypeParam - type_constraints: list of TypeConstraint - - Returns - ------- - func_type: FuncType - The function type. - """ + type_params=None, + type_constraints=None): + if type_params is None: + type_params = [] + if type_constraints is None: + type_constraints = [] self.__init_handle_by_constructor__( _make.FuncType, arg_types, ret_type, type_params, type_constraints) @@ -175,3 +179,21 @@ class TypeRelation(TypeConstraint): def __init__(self, func, args, num_inputs, attrs): self.__init_handle_by_constructor__(_make.TypeRelation, func, args, num_inputs, attrs) + + +def scalar_type(dtype): + """Creates a scalar type. + + This function returns TensorType((), dtype) + + Parameters + ---------- + dtype : str + The content data type. + + Returns + ------- + s_type: tvm.relay.TensorType + The result type. + """ + return TensorType((), dtype) diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 8bda7587f217..2d9180e4597b 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -16,71 +16,60 @@ using namespace runtime; Environment EnvironmentNode::make(tvm::Map global_funcs) { auto n = make_node(); n->functions = std::move(global_funcs); + + for (const auto& kv : n->functions) { + // set gloval var map + CHECK(!n->global_var_map_.count(kv.first->name_hint)) + << "Duplicate global function name " << kv.first->name_hint; + n->global_var_map_.Set(kv.first->name_hint, kv.first); + } return Environment(n); } -GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) { - auto global_id = global_map_.find(str); - if (global_id != global_map_.end()) { - return (*global_id).second; - } else { - auto id = GlobalVarNode::make(str); - this->global_map_.Set(str, id); - return id; - } +GlobalVar EnvironmentNode::GetGlobalVar(const std::string &name) { + auto it = global_var_map_.find(name); + CHECK(it != global_var_map_.end()) + << "Cannot find global var " << name << " in the Environment"; + return (*it).second; } -/*! - * \brief Add a new item to the global environment - * \note if the update flag is not set adding a duplicate - * definition will trigger an exception, otherwise we will - * update the definition if and only if it is type compatible. - */ -void EnvironmentNode::Add(const GlobalVar &var, - const Function &func, +void EnvironmentNode::Add(const GlobalVar& var, + const Function& func, bool update) { // Type check the item before we add it to the environment. auto env = GetRef(this); - - Expr checked_expr = InferType(env, var, func); - - if (const FunctionNode *func_node = checked_expr.as()) { - auto checked_func = GetRef(func_node); - auto type = checked_func->checked_type(); - - CHECK(type.as() == nullptr); - - if (functions.find(var) != functions.end()) { - if (!update) { - throw dmlc::Error("already have definition for XXXX."); - } - - auto old_type = functions[var].as()->checked_type(); - - if (!AlphaEqual(type, old_type)) { - throw dmlc::Error( - "Environment#update changes type, not possible in this mode."); - } - - this->functions.Set(var, checked_func); - } else { - this->functions.Set(var, checked_func); + Function checked_func = InferType(func, env, var); + auto type = checked_func->checked_type(); + CHECK(type.as() == nullptr); + if (functions.find(var) != functions.end()) { + if (!update) { + throw dmlc::Error("already have definition for XXXX."); + } + auto old_type = functions[var].as()->checked_type(); + if (!AlphaEqual(type, old_type)) { + throw dmlc::Error( + "Environment#update changes type, not possible in this mode."); } - } else { - LOG(FATAL) << "internal error: unknown item type, unreachable code"; } + this->functions.Set(var, checked_func); + // set gloval var map + CHECK(!global_var_map_.count(var->name_hint)) + << "Duplicate global function name " << var->name_hint; + global_var_map_.Set(var->name_hint, var); } -void EnvironmentNode::Update(const GlobalVar &var, const Function &func) { +void EnvironmentNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } -void EnvironmentNode::Remove(const GlobalVar & var) { +void EnvironmentNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->data.erase(var.node_); + auto gvar_node = global_var_map_.CopyOnWrite(); + gvar_node->data.erase(var->name_hint); } -Function EnvironmentNode::Lookup(const GlobalVar &var) { +Function EnvironmentNode::Lookup(const GlobalVar& var) { auto func = functions.find(var); if (func != functions.end()) { return (*func).second; @@ -89,14 +78,14 @@ Function EnvironmentNode::Lookup(const GlobalVar &var) { } } -Function EnvironmentNode::Lookup(const std::string &str) { - GlobalVar id = this->GetGlobalVar(str); +Function EnvironmentNode::Lookup(const std::string &name) { + GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } -void EnvironmentNode::Merge(const Environment &env) { +void EnvironmentNode::Update(const Environment &env) { for (auto pair : env->functions) { - this->functions.Set(pair.first, pair.second); + this->Update(pair.first, pair.second); } } @@ -134,10 +123,10 @@ TVM_REGISTER_API("relay._env.Environment_Lookup_str") *ret = env->Lookup(var); }); -TVM_REGISTER_API("relay._env.Environment_Merge") +TVM_REGISTER_API("relay._env.Environment_Update") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; - env->Merge(args[1]); + env->Update(args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 6e3c3454e97b..5bbcb0608e6f 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -217,6 +217,8 @@ class TextPrinter : return ConstScalar(dtype, static_cast(op->data->data)); } else if (dtype == Float(64)) { return ConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Bool()) { + return ConstScalar(dtype, static_cast(op->data->data)); } } // default fall-back, record it as meta node. @@ -638,8 +640,14 @@ class TextPrinter : * \return The corresponding name. */ TextValue AllocVarName(const Var& var) { - std::string name = GetUniqueName('%' + var->name_hint); - TextValue val(name); + std::string name = var->name_hint; + // always make sure first name is alpha + if (name.length() != 0 && !std::isalpha(name[0])) { + name = "%v" + name; + } else { + name = "%" + name; + } + TextValue val(GetUniqueName(name)); CHECK(!memo_.count(var)); memo_[var] = val; return val; diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index e6d60f9344a1..b4984becdf8b 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -78,6 +78,7 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ResizeAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 4a8df2c80ec3..8a7cffd2cd27 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -247,6 +247,8 @@ RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu") // Positional relay function to create LRN operator used by frontend FFI. +TVM_REGISTER_NODE_TYPE(LRNAttrs); + Expr MakeLRN(Expr data, IndexExpr size, IndexExpr axis, @@ -290,6 +292,8 @@ centered at that value (zero padding is added where necessary). // Positional relay function to create L2Normalize operator used by frontend FFI. +TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); + Expr MakeL2Normalize(Expr data, double eps, Array axis) { @@ -315,6 +319,7 @@ Normalizes along dimension axis using an L2 norm - **data**: The input tensor. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.L2NormalizeAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index b67bb96c64a9..da7db042178e 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -77,6 +77,7 @@ RELAY_REGISTER_OP("nn.pad") .describe(R"code(Pad for n-D tensor. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.PadAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 665eaf6de880..8c989ac91237 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -12,6 +12,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); +TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template bool Pool2DRel(const Array& types, @@ -115,6 +116,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") equation. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.MaxPool2DAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) @@ -169,6 +171,7 @@ Average pooling operation for one dimensional data. equation. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.AvgPool2DAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) @@ -232,6 +235,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) @@ -261,6 +265,7 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.GlobalPool2DAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index a429a7c40e82..45bedd73c4c0 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.upsampling") (batch_size, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.UpSamplingAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index d2ec24688633..017ef1e5dfec 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -199,7 +199,7 @@ RELAY_REGISTER_REDUCE_OP("argmax") values over a given axis. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel); @@ -209,7 +209,7 @@ RELAY_REGISTER_REDUCE_OP("argmin") values over a given axis. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ReduceAttrs") .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ea67199f4760..61ee2778d0a2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -144,12 +144,14 @@ RELAY_REGISTER_OP("concatenate") - **axis** : The axis along which the tensors are concatenated. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ConcatenateAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input list of tensors.") .set_support_level(1) .add_type_rel("Concatenate", ConcatenateRel); /* relay.transpose */ +TVM_REGISTER_NODE_TYPE(TransposeAttrs); bool TransposeRel(const Array& types, int num_inputs, @@ -224,12 +226,15 @@ RELAY_REGISTER_OP("transpose") )code" TVM_ADD_FILELINE) .set_num_inputs(1) +.set_attrs_type_key("relay.attrs.TransposeAttrs") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) .add_type_rel("Transpose", TransposeRel); /* relay.reshape */ +TVM_REGISTER_NODE_TYPE(ReshapeAttrs); + bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -310,6 +315,7 @@ Example:: )code" TVM_ADD_FILELINE) .set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ReshapeAttrs") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) .add_type_rel("Reshape", ReshapeRel); @@ -397,12 +403,14 @@ Examples:: [ 4., 3.]] )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.TakeAttrs") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") .set_support_level(2) .add_type_rel("Take", TakeRel); +// Init ops TVM_REGISTER_NODE_TYPE(InitOpAttrs); bool FullRel(const Array& types, @@ -448,6 +456,7 @@ RELAY_REGISTER_OP("full") .describe(R"code(Fill array with scalar value. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.InitOpAttrs") .set_num_inputs(1) .add_argument("fill_value", "double", "The value to fill.") .set_support_level(3) @@ -634,6 +643,10 @@ Examples:: .set_support_level(4) .add_type_rel("Where", WhereRel); + +// Squeeze +TVM_REGISTER_NODE_TYPE(SqueezeAttrs); + Expr MakeSqueeze(Expr data, Array axes) { auto attrs = make_node(); diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 169ef35474e2..467c0fcde860 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "./type_relations.h" @@ -21,14 +22,6 @@ TensorType ToTensorType(const Type& t) { } } -// TODO(@jroesch) what size value do we extract, 64bit or 32bit? -int ToInt(const tvm::Expr& e) { - CHECK(e.defined()); - auto imm = e.as(); - CHECK(imm) << "TYPE: " << imm << imm->type << std::endl; - return imm->value; -} - bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -39,72 +32,54 @@ bool IdentityRel(const Array& types, return true; } +bool EqualCheck(const IndexExpr& lhs, + const IndexExpr& rhs) { + IndexExpr diff = lhs - rhs; + if (const int64_t* pdiff = as_const_int(diff)) { + return pdiff[0] == 0; + } + // symbolic + diff = tvm::ir::CanonicalSimplify(diff); + if (const int64_t* pdiff = as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + +bool EqualConstInt(const IndexExpr& lhs, int64_t value) { + if (const int64_t* pvalue = as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { - RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2 - << std::endl; - auto sh1 = t1->shape; - auto sh2 = t2->shape; - RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2 - << std::endl; - if (sh1.size() == 0 && sh2.size() == 0) { - return TensorTypeNode::make({}, output_dtype); - // We have non-zero shapes so broadcast rules apply. - } else { - auto suffix_len = static_cast(std::min(sh1.size(), sh2.size())); - auto full_len = static_cast(std::max(sh1.size(), sh2.size())); - - auto rev_sh1 = sh1.rbegin(); - auto rev_sh2 = sh2.rbegin(); - - while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) { - auto dim1 = ToInt(*rev_sh1); - auto dim2 = ToInt(*rev_sh2); - if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) { - CHECK(false) << "Dimension mistmatch " - << "dim1: " << dim1 << " dim2: " << dim2 << std::endl; - } - rev_sh1++; - rev_sh2++; - } - - Array larger; - Array smaller; - - for (int i = 0; i < (full_len - suffix_len); i++) { - smaller.push_back(make_const(tvm::Int(64), 1)); - } - - if (sh1.size() < sh2.size()) { - for (auto sh : sh1) { - smaller.push_back(sh); - } - larger = sh2; - } else if (sh1.size() > sh2.size()) { - for (auto sh : sh1) { - larger.push_back(sh); - } - smaller = sh2; + std::vector oshape; + size_t ndim1 = t1->shape.size(); + size_t ndim2 = t2->shape.size(); + size_t i = 1; + for (; i <= std::min(ndim1, ndim2); ++i) { + IndexExpr s1 = t1->shape[ndim1 - i]; + IndexExpr s2 = t2->shape[ndim2 - i]; + if (EqualCheck(s1, s2)) { + oshape.push_back(s1); + } else if (EqualConstInt(s1, 1)) { + oshape.push_back(s2); + } else if (EqualConstInt(s2, 1)) { + oshape.push_back(s1); } else { - larger = sh1; - smaller = sh2; + LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2; } - - CHECK_EQ(larger.size(), smaller.size()); - - Array out_shape; - for (size_t i = 0; i < smaller.size(); i++) { - auto left = smaller[i].as(); - auto right = larger[i].as(); - CHECK(left); - CHECK(right); - int64_t dim = std::max(left->value, right->value); - out_shape.push_back(make_const(tvm::Int(64), dim)); - } - - return TensorTypeNode::make(out_shape, output_dtype); } + size_t max_ndim = std::max(ndim1, ndim2); + auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape; + for (; i <= max_ndim; ++i) { + oshape.push_back(rshape[max_ndim - i]); + } + return TensorTypeNode::make(Array( + oshape.rbegin(), oshape.rend()), output_dtype); } bool BroadcastRel(const Array& types, @@ -141,71 +116,5 @@ bool BroadcastCompRel(const Array& types, return false; } -/*! \brief Handle concrete concat case from known input to output. */ -inline Type ConcreteConcatRel(const Type& input_type) { - if (auto tuple_node = input_type.as()) { - // NB: For now the axis argument is hardwired to be 0. - std::vector dims; - DataType dtype; - - CHECK_LT(1, tuple_node->fields.size()); - bool skip_first = true; - - // Collect the suffix dimensions since axis is zero. - // TODO(@jroesch): This is a demonstration of how - // to do varargs. It requires a little more work to - // fully type the behavior of concat. - - auto first = Downcast(tuple_node->fields[0]); - dtype = first->dtype; - - for (auto dim_expr : first->shape) { - if (!skip_first) { - dims.push_back(ToInt(dim_expr)); - } else { - skip_first = false; - } - } - - std::vector axis_dims; - for (auto field_ty : tuple_node->fields) { - auto ttype = Downcast(field_ty); - for (size_t i = 0; i < ttype->shape.size(); i++) { - if (i != 0) { - CHECK_EQ(ToInt(dims[i - 1]), ToInt(ttype->shape[i])); - } else { - axis_dims.push_back(ToInt(ttype->shape[i])); - } - } - } - - auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0); - - Array out_shape = { make_const(Int(64), out_axis_dim) }; - - for (auto dim : dims) { - out_shape.push_back(make_const(Int(64), dim)); - } - - return TensorTypeNode::make(out_shape, dtype); - - } else { - throw TypeRelationError("concat can only be used with a tuple as its argument"); - } -} - -bool ConcatRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - if (types[0].as()) { - reporter->Assign(types[1], ConcreteConcatRel(types[0])); - return true; - } - return false; -} - - } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index f6e94e24caa9..534e917a0b6c 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -13,17 +13,6 @@ namespace tvm { namespace relay { - -/*! \brief The error raised by a type relation. - * - * This error is how a type relation signals that it has failed. - * - */ -struct TypeRelationError : Error { - explicit TypeRelationError(const std::string& msg) - : Error(msg) {} -}; - /*! * \brief The identity type relation, all the types are equal. * @@ -72,22 +61,6 @@ bool BroadcastCompRel(const Array& types, const Attrs& attrs, const TypeReporter& reporter); -/*! - * \brief The concat type relation, implements the concatenating - * rule over the list of input types producing one concatenated - * type. - * - * \param types The input and output types to the relation. - * \param num_inputs The number of input arguments. - * \param attrs The attributes - * \param reporter The reporter. - * \return true whether relation has been resolved. - */ -bool ConcatRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter); - } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 63e75c0bb213..ce069a78186b 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -63,6 +63,7 @@ TVM_REGISTER_API("relay.op.vision._make.multibox_prior") RELAY_REGISTER_OP("vision.multibox_prior") .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." )doc" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.MultiBoxPriorAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(4) diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 5d153c606e63..0d2677e11c67 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -20,7 +20,9 @@ bool IsBoolLit(const Expr& e, bool b) { if (const ConstantNode* c = e.as()) { if (c->is_scalar()) { auto dt = c->tensor_type()->dtype; - if (dt == UInt(8)) { + if (dt == Bool()) { + return *reinterpret_cast(c->data->data) == b; + } else if (dt == UInt(8)) { return *reinterpret_cast(c->data->data) == b; } else if (dt == UInt(16)) { return *reinterpret_cast(c->data->data) == b; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1b30865eacb1..3801987c932f 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -28,6 +28,39 @@ namespace tvm { namespace relay { + +// Necessary deferred relation for TupleGetItem +struct TupleGetItemAttrs : public tvm::AttrsNode { + int index; + + TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { + TVM_ATTR_FIELD(index); + } +}; + +bool TupleGetItemRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + if (types[0].as()) return false; + const auto* data = types[0].as(); + CHECK(data != nullptr) + << "TupleGetItem expect input type to be TupleType " + << " get " << types[0] << " instead"; + const auto* param = attrs.as(); + CHECK(param != nullptr); + CHECK_GE(param->index, 0); + CHECK_LT(param->index, data->fields.size()); + reporter->Assign(types[1], data->fields[param->index]); + return true; +} + +TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); +TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") +.set_body_typed&, int, const Attrs&, const TypeReporter&)>( + TupleGetItemRel); + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -38,8 +71,7 @@ namespace relay { class TypeInferencer : private ExprFunctor { public: // constructors - TypeInferencer() - : env_(EnvironmentNode::make({})) { + TypeInferencer() { } explicit TypeInferencer(Environment env) : env_(env) { @@ -58,6 +90,8 @@ class TypeInferencer : private ExprFunctor { std::unordered_map type_map_; // The solver used by the inferencer. TypeSolver solver_; + // relation function + TypeRelationFn tuple_getitem_rel_; // Unify two types Type Unify(const Type& t1, const Type& t2, const Span& span) { // TODO(tqchen, jroesch): propagate span to solver @@ -96,6 +130,8 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); + CHECK(env_.defined()) + << "Cannot do type inference without a global variable"; Expr e = env_->Lookup(var); return e->checked_type(); } @@ -116,17 +152,17 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const TupleGetItemNode* op) final { - // TODO(M.K.) - // handle case where field type is not known - Type tuple_type = GetType(op->tuple); - auto tuple_ty_node = tuple_type.as(); - if (!tuple_ty_node) { - LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef(op); - } - if (static_cast(tuple_ty_node->fields.size()) <= op->index) { - LOG(FATAL) << "tuple not big enough" << GetRef(op); + if (!tuple_getitem_rel_.defined()) { + tuple_getitem_rel_ = TypeRelationFn( + EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_); } - return tuple_ty_node->fields[op->index]; + Type tuple_type = GetType(op->tuple); + Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + auto attrs = make_node(); + attrs->index = op->index; + solver_.AddConstraint(TypeRelationNode::make( + tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs))); + return rtype; } Type VisitExpr_(const OpNode* op) final { @@ -305,7 +341,6 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } - Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); } @@ -363,20 +398,21 @@ Expr TypeInferencer::Infer(Expr expr) { return Resolver(type_map_, &solver_).VisitExpr(expr); } -Expr InferType(const Environment& env, const Expr& expr) { + +Expr InferType(const Expr& expr, const Environment& env) { return TypeInferencer(env).Infer(expr); } -Expr InferType(const Environment& env, - const GlobalVar& var, - const Function& func) { +Function InferType(const Function& func, + const Environment& env, + const GlobalVar& var) { Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); env->functions.Set(var, func_copy); Expr func_ret = TypeInferencer(env).Infer(func_copy); auto map_node = env->functions.CopyOnWrite(); map_node->data.erase(var.node_); - return func_ret; + return Downcast(func_ret); } TVM_REGISTER_API("relay._ir_pass.infer_type") diff --git a/tests/python/relay/test_ir_builder.py b/tests/python/relay/test_ir_builder.py deleted file mode 100644 index 165c66f17ac3..000000000000 --- a/tests/python/relay/test_ir_builder.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np -from tvm.relay.expr import Let, Constant -from tvm.relay.ir_builder import IRBuilder - -def test_let(): - b = IRBuilder() - x = b.let('x', 1) - b.ret(x) - prog, _ = b.get() - assert isinstance(prog, Let) - var = prog.var - value = prog.value - assert var.name_hint == 'x' - assert var == prog.body - assert isinstance(value, Constant) - assert value.data.asnumpy() == np.array(1) - -if __name__ == "__main__": - test_let() diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index e571f2a9c99a..20c45f5a16c5 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -173,7 +173,7 @@ def test_if(): def test_tuple_get_item(): tup = relay.Var("tuple") get = relay.TupleGetItem(tup, 1) - assert get.tuple == tup + assert get.tuple_value == tup assert get.index == 1 str(get) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 79a4fdd010c5..29814ecc5eb7 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -27,7 +27,7 @@ def test_env(): z = relay.add(z, z) f = relay.Function([x, y], z) env = relay.Environment() - env.add("myf", f) + env["myf"] = f text = env.astext() assert "def @myf" in text assert "%1 = add(%0, %0) # ty=float32" in text @@ -70,15 +70,18 @@ def test_let_if_scope(): x = relay.var("x", "float32") y = relay.var("y", "float32") cond = relay.var("cond", "bool") - v1 = relay.var("v") - v2 = relay.var("v", "float32") - then_branch = relay.Let( - v1, relay.const(1, "float32"), - relay.Let(v2, x, relay.subtract(v1, v2))) - v3 = relay.var("v") - let2 = relay.Let(v3, y, v3) - else_branch = relay.add(let2, let2) - result = relay.If(cond, then_branch, else_branch) + + sb = relay.ScopeBuilder() + with sb.if_scope(cond): + v1 = sb.let("v", relay.const(1, "float32")) + v2 = sb.let("v", x) + sb.ret(relay.subtract(v1, v2)) + with sb.else_scope(): + v3 = relay.var("v") + let2 = relay.Let(v3, y, v3) + sb.ret(relay.add(let2, let2)) + result = sb.get() + f = relay.Function([x, y, cond], result) text = f.astext() assert text.count("{") == 4 @@ -86,10 +89,17 @@ def test_let_if_scope(): show(f.astext()) +def test_variable_name(): + # avoid pure number even if the namehint is pure number + v1 = relay.var("1") + assert "%v1" in v1.astext() + + if __name__ == "__main__": do_print[0] = True - test_let_if_scope() test_func() test_env() test_meta_data() test_call_attrs() + test_let_if_scope() + test_variable_name() diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 914eafeb57a9..5afae6e872d1 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -1,282 +1,143 @@ import tvm import numpy as np from tvm import relay -from tvm.relay.ir_pass import infer_type -from tvm.relay.ir_builder import IRBuilder, func_type -from tvm.relay.ir_builder import scalar_type, convert, tensor_type -from tvm.relay.env import Environment -def assert_has_type(expr, typ, env=Environment({})): - checked_expr = infer_type(env, expr) - checked_type = checked_expr.checked_type - if checked_type != typ: - raise RuntimeError("Type mismatch %s vs %s" % ( - checked_type, typ)) -def test_single_op(): +def test_unary_op(): def check_single_op(opfunc): - "Program: fn (x : float32) { let t1 = f(x); t1 }" - b = IRBuilder() - with b.function(('x', 'float32')) as func: - x, = func.param_ids() - t1 = b.let('t1', opfunc(x)) - b.ret(t1) - assert_has_type(func.to_func(), func_type(['float32'], 'float32')) - - for opfunc in [tvm.relay.log, tvm.relay.exp, tvm.relay.sqrt, - tvm.relay.sigmoid, tvm.relay.tanh]: + tp = relay.TensorType((10, 4), "float32") + x = relay.var("x", tp) + y = opfunc(x) + # test printer + assert ("%0 = {}(%x)".format(y.op.name)) in y.astext() + # test type inference + assert relay.ir_pass.infer_type(y).checked_type == tp + + for opfunc in [tvm.relay.log, + tvm.relay.exp, + tvm.relay.sqrt, + tvm.relay.sigmoid, + tvm.relay.tanh, + relay.nn.relu]: check_single_op(opfunc) +def test_binary_op(): + def check_binary_op(opfunc): + n = tvm.var("n") + t1 = relay.TensorType((5, n, 5)) + t2 = relay.TensorType((n, 1)) + x = relay.var("x", t1) + y = relay.var("y", t2) + z = opfunc(x, y) + # test printer + assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() + assert relay.ir_pass.infer_type(z).checked_type == t1 + + for opfunc in [relay.add, + relay.subtract, + relay.mod, + relay.multiply, + relay.divide]: + check_binary_op(opfunc) + def test_expand_dims_infer_type(): - ib = relay.ir_builder.IRBuilder() n, t, d = tvm.var("n"), tvm.var("t"), 100 - # let's mimic a batch of sequences - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.expand_dims(x, axis=2)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (n, t, 1, 100), "float32") + x = relay.var("x", shape=(n, t, d)) + y = relay.expand_dims(x, axis=2) + assert "axis=2" in y.astext() + checked = relay.ir_pass.infer_type(y) + assert checked.checked_type == relay.TensorType((n, t, 1, 100)) def test_softmax(): - ib = relay.ir_builder.IRBuilder() n, d = tvm.var("n"), tvm.var("d") - x = ib.param("x", relay.ty.TensorType((n, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.softmax(x, axis=1)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, d), "float32") + x = relay.var("x", shape=(n, d)) + y = relay.nn.softmax(x, axis=1) + assert "nn.softmax" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, d)) def test_log_softmax(): - ib = relay.ir_builder.IRBuilder() n, d = tvm.var("n"), tvm.var("d") - x = ib.param("x", relay.ty.TensorType((n, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.log_softmax(x, axis=1)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, d), "float32") - -def test_unary_op(): - for op in [relay.exp, - relay.log, - relay.sqrt, - relay.sigmoid, - relay.nn.relu]: - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.TensorType((10, 4), "int32")) - with ib.function(x) as func: - ib.ret(op(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((10, 4), "int32") - - -def test_binary_op(): - def check_binary_op(opfunc): - """ - Program: - fn (x, y) { - return x y; - } - """ - b = IRBuilder() - - x = b.param('x', tensor_type(5, 5, 5)) - y = b.param('y', tensor_type(5, 5, 5)) - with b.function(x, y) as func: - b.ret(opfunc(x, y)) - b.ret(func) - prog, env = b.get() - ttype = tensor_type(5, 5, 5) - expected_ty = func_type([ttype, ttype], ttype) - assert_has_type(func.to_func(), expected_ty) - - for opfunc in [relay.add, relay.subtract, relay.mod, - relay.multiply, relay.divide]: - check_binary_op(opfunc) - - -def test_binary_broadcast_op(): - def check_binary_broadcast_op(opfunc): - """ - Program: - fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { - return x y; - } - """ - b = IRBuilder() - x = b.param('x', tensor_type(10, 4)) - y = b.param('y', tensor_type(5, 10, 1)) - with b.function(x, y) as func: - b.ret(opfunc(x, y)) - b.ret(func) - prog, env = b.get() - - expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)], - tensor_type(5, 10, 4)) - assert_has_type(func.to_func(), expected_ty) - - for opfunc in [relay.add, relay.subtract, relay.mod, - relay.multiply, relay.divide]: - check_binary_broadcast_op(opfunc) + x = relay.var("x", shape=(n, d)) + y = relay.nn.log_softmax(x, axis=0) + assert "nn.log_softmax" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, d)) def test_concatenate_infer_type(): - ib = relay.ir_builder.IRBuilder() - n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - y = ib.param("y", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x, y) as func: - ib.ret(relay.concatenate((x, y), axis=-1)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (n, t, 200), "float32") - - ib = relay.ir_builder.IRBuilder() n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - y = ib.param("y", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x, y) as func: - ib.ret(relay.concatenate((x, y), axis=2)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (n, t, 200), "float32") - - ib = relay.ir_builder.IRBuilder() - n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - y = ib.param("y", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x, y) as func: - ib.ret(relay.concatenate((x, y), axis=1)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (n, t + t, 100), "float32") - -def test_lrn(): - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c , h, w), "float32") + x = relay.var("x", shape=(n, t, d)) + y = relay.var("y", shape=(n, t, d)) + z = relay.concatenate((x, y), axis=-1) + assert "axis=" in z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((n, t, 200)) + z = relay.concatenate((x, y), axis=2) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((n, t, 200)) -def test_l2_normalize(): - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.l2_normalize(x, eps=0.001, axis=[1])) - ib.ret(func) + z = relay.concatenate((x, y), axis=1) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((n, t + t, 100)) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c , h, w), "float32") def test_dropout(): - ib = relay.ir_builder.IRBuilder() - input_ty = relay.ty.TensorType((3, 4, 5), "int8") - x = ib.param("x", input_ty) - with ib.function(x) as func: - ib.ret(relay.nn.dropout(x)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TupleType([input_ty, input_ty]) - - ib = relay.ir_builder.IRBuilder() n, t, d = tvm.var("n"), tvm.var("t"), tvm.var("d") - input_ty = relay.ty.TensorType((n, t, d), "float32") - x = ib.param("x", input_ty) - with ib.function(x) as func: - ib.ret(relay.nn.dropout(x, rate=0.75)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TupleType([input_ty, input_ty]) + input_ty = relay.TensorType((n, t, d), "float32") + x = relay.var("x", input_ty) + y, _ = relay.nn.dropout(x, rate=0.75) + assert "rate=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == input_ty def test_batch_norm(): # beta and gamma ignored - ib = relay.ir_builder.IRBuilder() - data = ib.param("data", relay.ty.TensorType((3, 2, 1), "float32")) - gamma = ib.param("gamma", relay.ty.TensorType((5,), "int8")) - beta = ib.param("beta", relay.ty.TensorType((12, 16), "int64")) - moving_mean = ib.param("moving_mean", relay.ty.TensorType((2,), "float32")) - moving_var = ib.param("moving_var", relay.ty.TensorType((2,), "float32")) - with ib.function(data, gamma, beta, moving_mean, moving_var) as func: - ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, - center=False, scale=False)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((3, 2, 1), "float32"), - relay.ty.TensorType((2,), "float32"), - relay.ty.TensorType((2,), "float32") + data = relay.var("data", relay.TensorType((3, 2, 1))) + beta = relay.var("beta", relay.TensorType((2,))) + gamma = relay.var("gamma", relay.TensorType((2,))) + moving_mean = relay.var("moving_mean", relay.TensorType((2,))) + moving_var = relay.var("moving_var", relay.TensorType((2,))) + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, + center=False, scale=False) + yy = relay.ir_pass.infer_type(y) + assert "center=" in yy.astext() + assert yy.checked_type == relay.ty.TupleType(tvm.convert([ + relay.TensorType((3, 2, 1), "float32"), + relay.TensorType((2,), "float32"), + relay.TensorType((2,), "float32") ])) - # with beta and gamma, different axis - ib = relay.ir_builder.IRBuilder() - data = ib.param("data", relay.ty.TensorType((3, 2, 1), "float32")) - gamma = ib.param("gamma", relay.ty.TensorType((3,), "float32")) - beta = ib.param("beta", relay.ty.TensorType((3,), "float32")) - moving_mean = ib.param("moving_mean", relay.ty.TensorType((3,), "float32")) - moving_var = ib.param("moving_var", relay.ty.TensorType((3,), "float32")) - with ib.function(data, gamma, beta, moving_mean, moving_var) as func: - ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, - axis=0, center=False, scale=False)) - ib.ret(func) + beta = relay.var("beta", relay.TensorType((3,))) + gamma = relay.var("gamma", relay.TensorType((3,))) + moving_mean = relay.var("moving_mean", relay.TensorType((3,))) + moving_var = relay.var("moving_var", relay.TensorType((3,))) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TupleType(tvm.convert([ + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, + axis=0, center=False, scale=False) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((3, 2, 1), "float32"), relay.ty.TensorType((3,), "float32"), relay.ty.TensorType((3,), "float32") ])) # axis=-1 - ib = relay.ir_builder.IRBuilder() - data = ib.param("data", relay.ty.TensorType((1, 2, 3), "float32")) - gamma = ib.param("gamma", relay.ty.TensorType((3,), "float32")) - beta = ib.param("beta", relay.ty.TensorType((3,), "float32")) - moving_mean = ib.param("moving_mean", relay.ty.TensorType((3,), "float32")) - moving_var = ib.param("moving_var", relay.ty.TensorType((3,), "float32")) - with ib.function(data, gamma, beta, moving_mean, moving_var) as func: - ib.ret(relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, - axis=-1, center=False, scale=False)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TupleType(tvm.convert([ + data = relay.var("data", relay.TensorType((1, 2, 3))) + beta = relay.var("beta", relay.TensorType((3,))) + gamma = relay.var("gamma", relay.TensorType((3,))) + moving_mean = relay.var("moving_mean", relay.TensorType((3,))) + moving_var = relay.var("moving_var", relay.TensorType((3,))) + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, + axis=-1, center=False, scale=False) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((1, 2, 3), "float32"), relay.ty.TensorType((3,), "float32"), relay.ty.TensorType((3,), "float32") @@ -285,14 +146,10 @@ def test_batch_norm(): if __name__ == "__main__": test_unary_op() - test_single_op() + test_binary_op() test_expand_dims_infer_type() test_concatenate_infer_type() test_softmax() test_log_softmax() - test_binary_op() - test_binary_broadcast_op() - test_lrn() - test_l2_normalize() test_dropout() test_batch_norm() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 4f37d4893b66..2f32b316924a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -3,162 +3,111 @@ import tvm from tvm import relay + def test_conv2d_infer_type(): # symbolic in batch dimension - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 10, 224, 224 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - w = ib.param("w", relay.ty.IncompleteType()) - - with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d(x, w, - kernel_size=(3, 3), - padding=(1, 1), - channels=2)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32")) + w = relay.var("w") + y = relay.nn.conv2d(x, w, + kernel_size=(3, 3), + padding=(1, 1), + channels=2) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (n, 2, 224, 224), "float32") - assert ftype.arg_types[1] == relay.ty.TensorType( + assert yy.args[1].checked_type == relay.TensorType( (2, 10, 3, 3), "float32") # infer by shape of w, mixed precision - ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), 10, 224, 224 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) - w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8")) - with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d(x, w, out_dtype="int32")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) + w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8")) + y = relay.nn.conv2d(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (n, 2, 222, 222), "int32") # Infer with a different layout - ib = relay.ir_builder.IRBuilder() n, c, h, w = 4, 32, 224, 224 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) - w = ib.param("w", relay.ty.IncompleteType()) - with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d(x, w, - kernel_size=(3, 3), - padding=(1, 1), - channels=16, - data_layout="NCHW4n4c", - weight_layout="OIHW4o4i", - out_dtype="int32")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) + w = relay.var("w") + y = relay.nn.conv2d(x, w, + kernel_size=(3, 3), + padding=(1, 1), + channels=16, + data_layout="NCHW4n4c", + weight_layout="OIHW4o4i", + out_dtype="int32") + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (1, 4, 224, 224, 4, 4), "int32") - assert ftype.arg_types[1] == relay.ty.TensorType( + assert yy.args[1].checked_type == relay.TensorType( (4, 8, 3, 3, 4, 4), "int8") def test_conv2d_transpose_infer_type(): # symbolic in batch dimension - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 10, 10, 12 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - w = ib.param("w", relay.ty.IncompleteType()) - - with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d_transpose(x, w, - kernel_size=(3, 3), - padding=(1, 1), - channels=15)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + w = relay.var("w", relay.IncompleteType()) + y = relay.nn.conv2d_transpose(x, w, + kernel_size=(3, 3), + padding=(1, 1), + channels=15) + assert "channels=15" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (n, 15, 10, 12), "float32") - assert ftype.arg_types[1] == relay.ty.TensorType( + assert yy.args[1].checked_type == relay.TensorType( (10, 15, 3, 3), "float32") # infer by shape of w, mixed precision - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 10, 10, 12 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32")) - with ib.function(x, w) as func: - ib.ret(relay.nn.conv2d_transpose(x, w, - output_padding=(1, 1), - channels=11, - data_layout="NHWC")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32")) + y = relay.nn.conv2d_transpose(x, w, + output_padding=(1, 1), + channels=11, + data_layout="NHWC") + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (n, 15, 15, 11), "float32") def test_upsampling_infer_type(): - ib = relay.ir_builder.IRBuilder() n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, h*2, w*2), "float32") - - ib = relay.ir_builder.IRBuilder() + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") + "method=\"BINLINEAR\"" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32") n, c = tvm.var("n"), tvm.var("c") - x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, 200, 400), "float32") + x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) + y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR") + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") def _test_pool2d_infer_type(opfunc): - ib = relay.ir_builder.IRBuilder() - n, c, h, w = tvm.var("n"), 10, 224, 224 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - with ib.function(x) as func: - ib.ret(opfunc(x, pool_size=(1, 1))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, 10, 224, 224), "float32") - - ph, pw = tvm.var("ph"), tvm.var("pw") - sh, sw = tvm.var("sh"), tvm.var("sw") - - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 10, 224, 224 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - with ib.function(x) as func: - ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (n, 10, (((224 - ph)/sh) + 1), (((224 - pw)/sw) + 1)), "float32") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = opfunc(x, pool_size=(1, 1)) + assert "pool_size=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32") def _test_global_pool2d_infer_type(opfunc): - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 - x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32")) - with ib.function(x) as func: - ib.ret(opfunc(x, layout="NHWC")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, 1, 1, c), "float32") + x = relay.var("x", relay.TensorType((n, h, w, c), "float32")) + y = opfunc(x, layout="NHWC") + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32") - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - with ib.function(x) as func: - ib.ret(opfunc(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, 1, 1), "float32") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = opfunc(x) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, 1, 1), "float32") def test_pool2d_infer_type(): _test_pool2d_infer_type(relay.nn.max_pool2d) @@ -167,101 +116,83 @@ def test_pool2d_infer_type(): _test_global_pool2d_infer_type(relay.nn.global_avg_pool2d) def test_flatten_infer_type(): - ib = relay.ir_builder.IRBuilder() d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") - x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32")) + x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32")) + y = relay.nn.batch_flatten(x) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((d1, ((d2*d3)*d4)), "float32") - with ib.function(x) as func: - ib.ret(relay.nn.batch_flatten(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((d1, ((d2*d3)*d4)), "float32") + x = relay.var("x", relay.TensorType((3, 2, 4, 3), "float32")) + y = relay.nn.batch_flatten(x) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((3, 24), "float32") - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.batch_flatten(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((3, 24), "float32") - - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32")) - with ib.function(x) as func: - ib.ret(relay.nn.batch_flatten(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((d1, ((2*d3)*3)), "float32") + x = relay.var("x", relay.TensorType((d1, 2, d3, 3), "float32")) + y = relay.nn.batch_flatten(x) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((d1, ((2*d3)*3)), "float32") def test_pad_infer_type(): # entirely concrete case - ib = relay.ir_builder.IRBuilder() n, c, h, w = 1, 2, 3, 4 - t = ib.param("t", relay.TensorType((n, c, h, w), "float32")) - with ib.function(t) as func: - ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((3, 6, 9, 12), "float32") + t = relay.var("t", relay.TensorType((n, c, h, w), "float32")) + y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))) + "pad_width=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32") # some symbolic values - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") - t = ib.param("t", relay.TensorType((n, c, h, w), "float32")) - with ib.function(t) as func: - ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32") + t = relay.var("t", relay.TensorType((n, c, h, w), "float32")) + y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32") def test_dense_infer_type(): - ib = relay.ir_builder.IRBuilder() n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - - w = ib.param("w", relay.ty.TensorType((w, 2), "float32")) + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + w = relay.var("w", relay.TensorType((w, 2), "float32")) + y = relay.nn.dense(x, w, units=2) + "units=2" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") - with ib.function(x, w) as func: - ib.ret(relay.nn.dense(x, w, units=2)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, h, 2), "float32") - - ib = relay.ir_builder.IRBuilder() n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) wh, ww = tvm.var("wh"), tvm.var("ww") - w = ib.param("w", relay.ty.TensorType((wh, ww), "float32")) + w = relay.var("w", relay.TensorType((wh, ww), "float32")) + y = relay.nn.dense(x, w) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32") - with ib.function(x, w) as func: - ib.ret(relay.nn.dense(x, w)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, h, ww), "float32") - - ib = relay.ir_builder.IRBuilder() n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + w = relay.var("w", relay.IncompleteType()) + y = relay.nn.dense(x, w, units=2) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32") + - w = ib.param("w", relay.ty.IncompleteType()) +def test_lrn(): + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", shape=(n, c , h, w)) + y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75) + "alpha=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c , h, w)) - with ib.function(x, w) as func: - ib.ret(relay.nn.dense(x, w, units=2)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, h, 2), "float32") +def test_l2_normalize(): + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", shape=(n, c , h, w)) + y = relay.nn.l2_normalize(x, eps=0.001, axis=[1]) + "axis=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c , h, w)) if __name__ == "__main__": + test_lrn() + test_l2_normalize() test_conv2d_infer_type() test_pool2d_infer_type() test_upsampling_infer_type() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 0605ac02339b..d1bff2940457 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -3,154 +3,92 @@ import tvm import numpy as np from tvm import relay -from tvm.relay.ir_pass import infer_type -from tvm.relay.ir_builder import IRBuilder, func_type -from tvm.relay.env import Environment from nose.tools import raises def test_zeros_ones(): for op in [relay.zeros, relay.ones]: - ib = relay.ir_builder.IRBuilder() - with ib.function() as func: - ib.ret(op((124, 50), "float64")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((124, 50), "float64") - + y = op(shape=(124, 50), dtype="float64") + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((124, 50), "float64") def test_unary_identity(): - for op in [relay.zeros_like, relay.ones_like]: - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.TensorType((8, 9, 4), "int32")) - with ib.function(x) as func: - ib.ret(op(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32") + for op in [relay.zeros_like, + relay.ones_like, + relay.ceil, + relay.floor, + relay.trunc, + relay.round, + relay.abs, + relay.copy, + relay.negative]: + x = relay.var("x", relay.TensorType((8, 9, 4), "float32")) + y = op(x) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((8, 9, 4), "float32") def test_clip_type(): - ib = relay.ir_builder.IRBuilder() - a = ib.param("a", relay.TensorType((10, 4), "float32")) - with ib.function(a) as func: - ib.ret(relay.clip(a, 1., 4.)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((10, 4), "float32") - - -def test_copy_infer_type(): - ib = relay.ir_builder.IRBuilder() - n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.copy(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (n, t, 100), "float32") + a = relay.var("a", relay.TensorType((10, 4), "float32")) + y = relay.clip(a, 1., 4.) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((10, 4), "float32") def test_transpose_infer_type(): - ib = relay.ir_builder.IRBuilder() n, t, d = tvm.var("n"), tvm.var("t"), 100 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.transpose(x, axes=(1, 0, 2))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, t, d), "float32")) + y = relay.transpose(x, axes=(1, 0, 2)) + "axes=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (t, n, 100), "float32") -def test_squeeze_default_axes_infer_type(): - ib = relay.ir_builder.IRBuilder() +def test_squeeze_infer_type(): n, t, d = 1, 4, 1 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.squeeze(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (4,), "float32") - + x = relay.var("x", relay.TensorType((n, t, d), "float32")) + y = relay.squeeze(x, axes=(2,)) + assert "axes=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( + (1, 4), "float32") -def test_squeeze_axes_infer_type(): - ib = relay.ir_builder.IRBuilder() n, t, d = 1, 4, 1 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.squeeze(x, axes=(2,))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( - (1, 4), "float32") + x = relay.var("x", relay.TensorType((n, t, d), "float32")) + y = relay.squeeze(x) + assert "axes=" not in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( + (4,), "float32") @raises(tvm._ffi.base.TVMError) def test_squeeze_bad_axes_infer_type(): - ib = relay.ir_builder.IRBuilder() n, t, d = 1, 4, 1 - x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) - with ib.function(x) as func: - ib.ret(relay.squeeze(x, axes=(1,))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type + x = relay.var("x", relay.TensorType((n, t, d), "float32")) + y = relay.squeeze(x, axes=(1,)) + yy = relay.ir_pass.infer_type(y) def test_reshape_infer_type(): - ib = relay.ir_builder.IRBuilder() n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20 - x = ib.param("x", relay.ty.TensorType((n, t, d1, d2), "float32")) - with ib.function(x) as func: - ib.ret(relay.reshape(x, newshape=(n, t, 2000))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, t, d1, d2), "float32")) + y = relay.reshape(x, newshape=(n, t, 2000)) + assert "newshape=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType( (n, t, 2000), "float32") -def assert_has_type(expr, typ, env=Environment({})): - checked_expr = infer_type(env, expr) - checked_type = checked_expr.checked_type - if checked_type != typ: - raise RuntimeError("Type mismatch %s vs %s" % ( - checked_type, typ)) - -def test_single_op(): - def check_single_op(opfunc): - "Program: fn (x : float32) { let t1 = f(x); t1 }" - b = IRBuilder() - with b.function(('x', 'float32')) as func: - x, = func.param_ids() - t1 = b.let('t1', opfunc(x)) - b.ret(t1) - assert_has_type(func.to_func(), func_type(['float32'], 'float32')) - - for opfunc in [tvm.relay.ceil, tvm.relay.floor, tvm.relay.trunc, - tvm.relay.round, tvm.relay.abs, tvm.relay.negative]: - check_single_op(opfunc) def test_take_infer_type(): def verify_take(dshape, indices_shape, oshape, axis=None): - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.ty.TensorType(dshape, "float32")) - indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32")) - with ib.function(x, indices) as func: - ib.ret(relay.take(x, indices, axis=axis)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType(oshape, "float32") + x = relay.var("x", relay.TensorType(dshape, "float32")) + indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) + y = relay.take(x, indices, axis=axis) + y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType(oshape, "float32") d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3") d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6") @@ -164,73 +102,52 @@ def verify_take(dshape, indices_shape, oshape, axis=None): def test_full(): # default settings: match input dtype - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.TensorType((), "int8")) - with ib.function(x) as func: - ib.ret(relay.full(x, ())) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((), "int8") + x = relay.var("x", relay.TensorType((), "int8")) + y = relay.full(x, ()) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((), "int8") # change the shape and dtype - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.TensorType((), "float32")) - with ib.function(x) as func: - ib.ret(relay.full(x, (1, 2), "int8")) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((1, 2), "int8") + x = relay.var("x", relay.TensorType((), "float32")) + y = relay.full(x, (1, 2), "int8") + "shape=" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((1, 2), "int8") def test_full_like(): # concrete shape - ib = relay.ir_builder.IRBuilder() - base = ib.param("base", relay.TensorType((1, 2, 3), "float32")) - fill = ib.param("fill", relay.TensorType((), "float32")) - with ib.function(base, fill) as func: - ib.ret(relay.full_like(base, fill)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((1, 2, 3), "float32") + base = relay.var("base", relay.TensorType((1, 2, 3), "float32")) + fill = relay.var("fill", relay.TensorType((), "float32")) + y = relay.full_like(base, fill) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((1, 2, 3), "float32") # symbolic shape - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") - base = ib.param("base", relay.TensorType((n, c, h, w), "float32")) - fill = ib.param("fill", relay.TensorType((), "float32")) - with ib.function(base, fill) as func: - ib.ret(relay.full_like(base, fill)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((n, c, h, w), "float32") + base = relay.var("base", relay.TensorType((n, c, h, w), "float32")) + fill = relay.var("fill", relay.TensorType((), "float32")) + y = relay.full_like(base, fill) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") def test_infer_type_leaky_relu(): - ib = relay.ir_builder.IRBuilder() n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - - with ib.function(x) as func: - ib.ret(relay.nn.leaky_relu(x, alpha=0.1)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = relay.nn.leaky_relu(x, alpha=0.1) + "alpha=0.1" in y.astext() + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, w), "float32") if __name__ == "__main__": - test_single_op() test_zeros_ones() test_unary_identity() test_clip_type() - test_copy_infer_type() test_transpose_infer_type() test_reshape_infer_type() test_take_infer_type() test_full() test_full_like() test_infer_type_leaky_relu() - test_squeeze_axes_infer_type() - test_squeeze_default_axes_infer_type() + test_squeeze_infer_type() + test_squeeze_bad_axes_infer_type() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index dea300422e45..c2b685affab4 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -1,66 +1,24 @@ import tvm import numpy as np from tvm import relay -from tvm.relay.ir_pass import infer_type -from tvm.relay.ir_builder import IRBuilder, func_type -from tvm.relay.ir_builder import scalar_type, convert, tensor_type -from tvm.relay.env import Environment - -def assert_has_type(expr, typ, env=Environment({})): - checked_expr = infer_type(env, expr) - checked_type = checked_expr.checked_type - if checked_type != typ: - raise RuntimeError("Type mismatch %s vs %s" % ( - checked_type, typ)) def test_binary_op(): def check_binary_op(opfunc): - """ - Program: - fn (x, y) { - return x y; - } - """ - b = IRBuilder() - - x = b.param('x', tensor_type(5, 5, 5)) - y = b.param('y', tensor_type(5, 5, 5)) - with b.function(x, y) as func: - b.ret(opfunc(x, y)) - b.ret(func) - prog, env = b.get() - ttype = tensor_type(5, 5, 5) - expected_ty = func_type([ttype, ttype], ttype) - assert_has_type(func.to_func(), expected_ty) + n = tvm.var("n") + t1 = relay.TensorType((5, n, 5)) + t2 = relay.TensorType((n, 1)) + x = relay.var("x", t1) + y = relay.var("y", t2) + z = opfunc(x, y) + # test printer + assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() + assert relay.ir_pass.infer_type(z).checked_type == t1 for opfunc in [relay.pow]: check_binary_op(opfunc) -def test_binary_broadcast_op(): - def check_binary_broadcast_op(opfunc): - """ - Program: - fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { - return x y; - } - """ - b = IRBuilder() - x = b.param('x', tensor_type(10, 4)) - y = b.param('y', tensor_type(5, 10, 1)) - with b.function(x, y) as func: - b.ret(opfunc(x, y)) - b.ret(func) - prog, env = b.get() - - expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)], - tensor_type(5, 10, 4)) - assert_has_type(func.to_func(), expected_ty) - - for opfunc in [relay.pow]: - check_binary_broadcast_op(opfunc) - def test_cmp_type(): for op in (relay.greater, relay.greater_equal, @@ -68,138 +26,59 @@ def test_cmp_type(): relay.less_equal, relay.equal, relay.not_equal): - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.TensorType((10, 4), "float32")) - y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) - with ib.function(x, y) as func: - ib.ret(op(x, y)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1") + x = relay.var("x", relay.TensorType((10, 4), "float32")) + y = relay.var("y", relay.TensorType((5, 10, 1), "float32")) + z = op(x, y) + z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((5, 10, 4), "bool") + -def test_binary_broadcast(): +def test_binary_int_broadcast(): for op in [relay.right_shift, relay.left_shift, relay.maximum, relay.minimum]: - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.TensorType((10, 4), "int32")) - y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) - with ib.function(x, y) as func: - ib.ret(op(x, y)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") - -def test_argmax(): - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmax(x, axis=(1,))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32") - - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmax(x, axis=(2,), keepdims=True)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32") - - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmax(x, axis=(2,), keepdims=True, exclude=True)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32") - -def test_argmin(): - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmax(x, axis=(1,))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, h, w), "int32") - - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmin(x, axis=(2,), keepdims=True)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c , 1, w), "int32") - - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmin(x, axis=(2,), keepdims=True, exclude=True)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((1, 1 , h, 1), "int32") - - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmin(x, axis=(2,1), keepdims=True, exclude=True)) - ib.ret(func) - - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((1, c , h, 1), "int32") - - ib = relay.ir_builder.IRBuilder() - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.argmin(x, axis=None, keepdims=True, exclude=True)) - ib.ret(func) + x = relay.var("x", relay.TensorType((10, 4), "int32")) + y = relay.var("y", relay.TensorType((5, 10, 1), "int32")) + z = op(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((5, 10, 4), "int32") + + +def test_arg_reduce(): + for op in [relay.argmax, relay.argmin]: + n, c , h, w = 10, 20, 3, 4 + x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32")) + z = relay.argmax(x, axis=(1,)) + "axis=" in z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32") + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32")) + z = relay.argmax(x, axis=(2,), keepdims=True) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType((n, c , 1, w), "int32") + + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32")) + z = relay.argmax(x, axis=(2,), keepdims=True, exclude=True) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType((1, 1 , h, 1), "int32") - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((1, 1 , 1, 1), "int32") def test_where(): - ib = relay.ir_builder.IRBuilder() - cond = ib.param("cond", relay.TensorType((3, 4), "float32")) - x = ib.param("x", relay.TensorType((3, 4), "float32")) - y = ib.param("y", relay.TensorType((3, 4), "float32")) - with ib.function(cond, x, y) as func: - ib.ret(relay.where(cond, x, y)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.TensorType((3, 4), "float32") + cond = relay.var("cond", relay.TensorType((3, 4), "float32")) + x = relay.var("x", relay.TensorType((3, 4), "float32")) + y = relay.var("y", relay.TensorType((3, 4), "float32")) + z = relay.where(cond, x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((3, 4), "float32") if __name__ == "__main__": test_binary_op() - test_binary_broadcast_op() test_cmp_type() - test_binary_broadcast() + test_binary_int_broadcast() test_where() - test_multibox_prior() - test_argmax() - test_argmin() + test_arg_reduce() diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e04bd9bab91a..4e554cd0cf81 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -4,26 +4,18 @@ from tvm import relay def test_resize_infer_type(): - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) + x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) th, tw = tvm.var("th"), tvm.var("tw") + z = relay.image.resize(x, (th, tw)) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") - with ib.function(x) as func: - ib.ret(relay.image.resize(x, (th, tw))) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, th, tw), "int8") - - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) - with ib.function(x) as func: - ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType((n, c, 100, 200), "int8") + x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) + z= relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False) + assert "size=" in z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @@ -34,29 +26,21 @@ def test_multibox_prior(): offsets = (0.2, 0.3) clip = True - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 3, 56, 56 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - with ib.function(x) as func: - ib.ret(relay.vision.multibox_prior(x, sizes, ratios, - steps, offsets, clip)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + z = relay.vision.multibox_prior(x, sizes, ratios, + steps, offsets, clip) + assert "sizes=" in z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType( (1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32") - ib = relay.ir_builder.IRBuilder() n, c, h, w = tvm.var("n"), 24, 32, 32 - x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) - - with ib.function(x) as func: - ib.ret(relay.vision.multibox_prior(x)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == relay.ty.TensorType( + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + z = relay.vision.multibox_prior(x) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType( (1, h * w, 4), "float32") diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 51c1d4a2715a..18959687a0fd 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -2,7 +2,6 @@ import numpy as np from tvm import relay from tvm.relay.ir_pass import alpha_equal -from tvm.relay.ir_builder import convert def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") @@ -164,11 +163,11 @@ def test_type_relation_alpha_equal(): def test_constant_alpha_equal(): - x = convert(1) - y = convert(2) + x = relay.const(1) + y = relay.const(2) assert alpha_equal(x, x) assert not alpha_equal(x, y) - assert alpha_equal(x, convert(1)) + assert alpha_equal(x, relay.const(1)) def test_var_alpha_equal(): @@ -180,9 +179,9 @@ def test_var_alpha_equal(): assert not alpha_equal(v1, v2) # let node allows for setting the eq_map - l1 = relay.Let(v1, convert(1), v1) - l2 = relay.Let(v2, convert(1), v2) - l3 = relay.Let(v1, convert(1), v2) + l1 = relay.Let(v1, relay.const(1), v1) + l2 = relay.Let(v2, relay.const(1), v2) + l3 = relay.Let(v1, relay.const(1), v2) assert alpha_equal(l1, l2) assert not alpha_equal(l1, l3) @@ -204,34 +203,34 @@ def test_tuple_alpha_equal(): # unit value is a valid tuple assert alpha_equal(relay.Tuple([]), relay.Tuple([])) - tup = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])]) - same = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])]) + tup = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) + same = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) assert alpha_equal(tup, same) # use the eq_map let_tup = relay.Let(v1, tup, v1) - let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3), - relay.Tuple([convert(4)])]), + let_mapped = relay.Let(v2, relay.Tuple([v2, relay.const(2), relay.const(3), + relay.Tuple([relay.const(4)])]), v2) assert alpha_equal(let_tup, let_mapped) - more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2]) + more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2]) assert not alpha_equal(tup, more_fields) - fewer_fields = relay.Tuple([v1, convert(2), convert(3)]) + fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)]) assert not alpha_equal(tup, fewer_fields) - different_end = relay.Tuple([v1, convert(2), convert(3), - relay.Tuple([convert(5)])]) + different_end = relay.Tuple([v1, relay.const(2), relay.const(3), + relay.Tuple([relay.const(5)])]) assert not alpha_equal(tup, different_end) - different_start = relay.Tuple([v2, convert(2), convert(3), - relay.Tuple([convert(4)])]) + different_start = relay.Tuple([v2, relay.const(2), relay.const(3), + relay.Tuple([relay.const(4)])]) assert not alpha_equal(tup, different_start) - longer_at_end = relay.Tuple([v1, convert(2), convert(3), - relay.Tuple([convert(4), convert(5)])]) + longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3), + relay.Tuple([relay.const(4), relay.const(5)])]) assert not alpha_equal(tup, longer_at_end) @@ -319,11 +318,11 @@ def test_call_alpha_equal(): tt1 = relay.TensorType((1, 2, 3), "float32") tt2 = relay.TensorType((), "int8") - basic_args = [convert(1), convert(2), v2, relay.Tuple([])] + basic_args = [relay.const(1), relay.const(2), v2, relay.Tuple([])] # manually writing out args to ensure that args does not rely on # pointer equality - call = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([])], + call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])], attr1, [tt1]) same = relay.Call(v1, basic_args, attr1, [tt1]) assert alpha_equal(call, same) @@ -331,19 +330,19 @@ def test_call_alpha_equal(): different_fn = relay.Call(v2, basic_args, attr1, [tt1]) assert not alpha_equal(call, different_fn) - fewer_args = relay.Call(v1, [convert(1), convert(2), v2], attr1, [tt1]) + fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1]) assert not alpha_equal(call, fewer_args) - reordered_args = relay.Call(v1, [convert(2), convert(1), + reordered_args = relay.Call(v1, [relay.const(2), relay.const(1), relay.Tuple([]), v2], attr1, [tt1]) assert not alpha_equal(call, reordered_args) - different_args = relay.Call(v1, [convert(1), convert(2), convert(3)], + different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)], attr1, [tt1]) assert not alpha_equal(call, different_args) - more_args = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([]), - convert(3), convert(4)], attr1, [tt1]) + more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]), + relay.const(3), relay.const(4)], attr1, [tt1]) assert not alpha_equal(call, more_args) different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) @@ -367,27 +366,27 @@ def test_let_alpha_equal(): v2 = relay.Var("v2") v3 = relay.Var("v3") - let = relay.Let(v1, convert(2), v1) - mapped = relay.Let(v2, convert(2), v2) + let = relay.Let(v1, relay.const(2), v1) + mapped = relay.Let(v2, relay.const(2), v2) assert alpha_equal(let, mapped) - mismatched_var = relay.Let(v2, convert(2), v3) + mismatched_var = relay.Let(v2, relay.const(2), v3) assert not alpha_equal(let, mismatched_var) - different_value = relay.Let(v2, convert(3), v2) + different_value = relay.Let(v2, relay.const(3), v2) assert not alpha_equal(let, different_value) - different_body = relay.Let(v2, convert(3), convert(12)) + different_body = relay.Let(v2, relay.const(3), relay.const(12)) assert not alpha_equal(let, different_body) # specified types must match - let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype) - same_type = relay.Let(v1_wtype, convert(2), v1_wtype) + let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype) + same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype) assert alpha_equal(let_with_type, same_type) assert not alpha_equal(let, let_with_type) v2 = relay.Var("v1", tt2) - different_type = relay.Let(v2, convert(2), v2) + different_type = relay.Let(v2, relay.const(2), v2) assert not alpha_equal(let_with_type, different_type) @@ -395,17 +394,17 @@ def test_if_alpha_equal(): v1 = relay.Var("v1") v2 = relay.Var("v2") - if_sample = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)])) - same = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)])) + if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)])) + same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)])) assert alpha_equal(if_sample, same) - different_cond = relay.If(v2, convert(1), relay.Tuple([convert(2), convert(3)])) + different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)])) assert not alpha_equal(if_sample, different_cond) - different_true = relay.If(v1, convert(2), relay.Tuple([convert(2), convert(3)])) + different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)])) assert not alpha_equal(if_sample, different_true) - different_false = relay.If(v1, convert(1), relay.Tuple([])) + different_false = relay.If(v1, relay.const(1), relay.Tuple([])) assert not alpha_equal(if_sample, different_false) diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index c4bacce3ddfc..f74aaf74e474 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -1,7 +1,6 @@ import tvm from tvm import relay from tvm.relay.ir_pass import dead_code_elimination, alpha_equal -from tvm.relay.ir_builder import convert, IRBuilder from tvm.relay.op import log, add, equal, subtract @@ -19,9 +18,9 @@ def __init__(self): self.tt = relay.TensorType(self.shape, "float32") self.int32 = relay.TensorType([], "int32") self.float32 = relay.TensorType([], "float32") - self.one = convert(1.0) - self.two = convert(2.0) - self.three = convert(3.0) + self.one = relay.const(1.0) + self.two = relay.const(2.0) + self.three = relay.const(3.0) e = env() @@ -58,9 +57,12 @@ def test_recursion(): f = relay.Var("f") n = relay.Var("n", e.int32) data = relay.Var("data", e.float32) - funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) + funcbody = relay.If(equal(n, relay.const(0)), + data, + relay.Call(f, [subtract(n, relay.const(1.0)), + log(data)])) value = relay.Function([n, data], funcbody, e.float32, []) - orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0))) + orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) @@ -70,8 +72,10 @@ def test_op_let(): def test_if(): - orig = relay.If(convert(True), e.a, e.b) - assert alpha_equal(dead_code_elimination(orig), e.a) + cond = relay.const(True) + orig = relay.If(cond, e.a, e.b) + y = dead_code_elimination(orig) + assert alpha_equal(y, e.a) def test_tuple_get_item(): @@ -82,10 +86,10 @@ def test_tuple_get_item(): if __name__ == "__main__": + test_if() test_let() test_used_let() test_chain_unused_let() test_recursion() test_op_let() - test_if() test_tuple_get_item() diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 77b04590df59..2d8f98974639 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -4,34 +4,17 @@ import tvm import numpy as np from tvm.relay.ir_pass import infer_type -from tvm.relay.ir_builder import IRBuilder, func_type -from tvm.relay.ir_builder import scalar_type, convert, tensor_type -from tvm.relay.env import Environment -from tvm.relay.op import log, add, equal, subtract, concatenate -from tvm.relay.expr import Function from tvm import relay -def assert_has_type(expr, typ, env=Environment({})): - checked_expr = infer_type(env, expr) - checked_type = checked_expr.checked_type - if checked_type != typ: - raise RuntimeError("Type mismatch %s vs %s" % ( - checked_type, typ)) - - -def assert_decl_has_type(env, name, typ): - func = env[name] - assert func.checked_type == typ - def test_monomorphic_let(): "Program: let x = 1; return x" - b = IRBuilder() - x = b.let('x', 1.0, value_type=scalar_type('float64')) - b.ret(x) + sb = relay.ScopeBuilder() + x = sb.let('x', relay.const(1.0, "float64")) + sb.ret(x) + xchecked = relay.ir_pass.infer_type(sb.get()) + assert xchecked.checked_type == relay.scalar_type("float64") - prog, env = b.get() - assert_has_type(prog, scalar_type('float64')) def test_dual_op(): """Program: @@ -41,31 +24,29 @@ def test_dual_op(): return t1; } """ - b = IRBuilder() - with b.function(('x', tensor_type(10, 10))) as func: - x, = func.param_ids() - t1 = b.let('t1', log(x)) - t2 = b.let('t2', add(t1, x)) - b.ret(t2) - - assert_has_type(func.to_func(), - func_type([tensor_type(10, 10)], tensor_type(10, 10))) + tp = relay.TensorType((10, 10), "float32") + x = relay.var("x", tp) + sb = relay.ScopeBuilder() + t1 = sb.let("t1", relay.log(x)) + t2 = sb.let("t2", relay.add(t1, x)) + sb.ret(t2) + f = relay.Function([x], sb.get()) + fchecked = relay.ir_pass.infer_type(f) + assert fchecked.checked_type == relay.FuncType([tp], tp) def test_decl(): """Program: - def f(x : Tensor[f32, (10, 10)]) { - let lx = log(x); - return lx; + def f(x : Tensor[(10, 10), f32]) { + return log(x); } """ - b = IRBuilder() - x = b.param('x') - with b.decl('f', x): - lx = b.let('lx', log(x)) - b.ret(lx) - _, env = b.get() - assert_decl_has_type(env, 'f', func_type(['float32'], 'float32')) + sb = relay.ScopeBuilder() + tp = relay.TensorType((10, 10)) + x = relay.var("x", tp) + f = relay.Function([x], relay.log(x)) + fchecked = relay.ir_pass.infer_type(f) + assert fchecked.checked_type == relay.FuncType([tp], tp) def test_recursion(): @@ -78,54 +59,44 @@ def f(n: i32, data: f32) -> f32 { return f(n - 1, log(data)); } } - f(2, 10000); """ - b = IRBuilder() - f = b.global_var('f') - n = b.param('n', ty='int32') - data = b.param('data', ty='float32') - with b.decl(f, n, data): - with b.if_scope(equal(n, convert(0))): - b.ret(data) - with b.else_scope(): - b.ret(f(subtract(n, convert(1)), log(data))) - b.ret(f(convert(2.0), convert(10000.0))) - assert_decl_has_type(b.env, 'f', func_type( - ['int32', 'float32'], 'float32')) - # TODO(@jroesch): need evaluator or new runtime - # to execute this. + sb = relay.ScopeBuilder() + f = relay.GlobalVar("f") + ti32 = relay.scalar_type("int32") + tf32 = relay.scalar_type("float32") + n = relay.var("n", ti32) + data = relay.var("data", tf32) + + with sb.if_scope(relay.equal(n, relay.const(0, ti32))): + sb.ret(data) + with sb.else_scope(): + sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) + env = relay.Environment() + env[f] = relay.Function([n, data], sb.get()) + assert "%3 = @f(%1, %2)" in env.astext() + assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32) -def test_concat(): - """ - Program: - def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) { - return concatenate((x, y), axis=0); - } - """ - ib = IRBuilder() - try_concat2 = ib.global_var('try_concat2') - x = ib.param('x', ty=tensor_type(3, 2)) - y = ib.param('y', ty=tensor_type(2, 2)) - with ib.decl(try_concat2, x, y): - ib.ret(concatenate((x, y), axis=0)) - fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2)) - assert_decl_has_type(ib.env, try_concat2, fn_ty) def test_tuple(): - ib = IRBuilder() - dup = ib.global_var('dup') - x = ib.param('x') - with ib.decl(dup, x): - ib.ret(relay.Tuple([x, x])) - # todo: why is this not generalized? - fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()])) - assert_decl_has_type(ib.env, dup, fn_ty) + tp = relay.TensorType((10,)) + x = relay.var("x", tp) + res = relay.Tuple([x, x]) + assert (relay.ir_pass.infer_type(res).checked_type == + relay.TupleType([tp, tp])) + + +def test_free_expr(): + x = relay.var("x", "float32") + y = relay.add(x, x) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == relay.scalar_type("float32") + if __name__ == "__main__": + test_free_expr() test_dual_op() test_recursion() test_monomorphic_let() test_decl() test_recursion() - test_concat() test_tuple() diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index c96ca59d2c8d..e8ff67756931 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -1,7 +1,5 @@ import tvm - from tvm import relay -from tvm.relay.ir_builder import scalar_type, convert, tensor_type def make_rel(name, args, num_inputs=None, attrs=None): From 9a8d2c062347ee97a26b61163a9d53f223fa1aed Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 19 Oct 2018 13:18:24 -0700 Subject: [PATCH 2/2] Rename TypeParam->TypeVar for consistency --- include/tvm/relay/expr.h | 4 +- include/tvm/relay/op.h | 6 +-- include/tvm/relay/pass.h | 4 +- include/tvm/relay/type.h | 32 +++++++-------- python/tvm/relay/__init__.py | 2 +- python/tvm/relay/ty.py | 10 ++--- src/relay/ir/environment.cc | 23 +++++------ src/relay/ir/expr.cc | 2 +- src/relay/ir/expr_functor.cc | 4 +- src/relay/ir/type.cc | 22 +++++------ src/relay/pass/alpha_eq.cc | 10 ++--- src/relay/pass/kind_check.cc | 4 +- src/relay/pass/let_list.h | 2 +- src/relay/pass/type_functor.h | 4 +- src/relay/pass/type_infer.cc | 10 ++--- src/relay/pass/type_subst.cc | 12 +++--- src/relay/pass/type_subst.h | 4 +- src/relay/pass/type_visitor.h | 12 +++--- src/relay/pass/util.cc | 24 +++++------ tests/python/relay/test_ir_nodes.py | 6 +-- tests/python/relay/test_pass_alpha_equal.py | 26 ++++++------ tests/python/relay/test_pass_check_kind.py | 44 ++++++++++----------- tests/python/relay/test_pass_free_vars.py | 2 +- 23 files changed, 132 insertions(+), 137 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index d0b58e0213c7..142982d48907 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -197,7 +197,7 @@ class FunctionNode : public ExprNode { * * \note This can be usually empty for non-polymorphic functions. */ - tvm::Array type_params; + tvm::Array type_params; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("params", ¶ms); @@ -219,7 +219,7 @@ class FunctionNode : public ExprNode { TVM_DLL static Function make(tvm::Array params, Expr body, Type ret_type, - tvm::Array ty_params); + tvm::Array ty_params); static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 4dcff22b84e8..fe6d957e79ed 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -371,14 +371,14 @@ inline OpRegistry& OpRegistry::add_type_rel( env_type_rel_func = env_func; } - Array type_params; + Array type_params; Array arg_types; // Add inputs. std::string input_name_prefix = "in"; for (int i = 0; i < get()->num_inputs; i++) { auto name = input_name_prefix + std::to_string(i); - auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType); + auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType); type_params.push_back(param); arg_types.push_back(param); } @@ -386,7 +386,7 @@ inline OpRegistry& OpRegistry::add_type_rel( Array ty_call_args = arg_types; // Add output type. - auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType); + auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType); type_params.push_back(out_param); // this will trigger copy on write. ty_call_args.push_back(out_param); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 04f6a1842ee6..9a3b75364167 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -120,7 +120,7 @@ tvm::Array FreeVariables(const Expr& e); * * \return the set of free type variables. */ -tvm::Array FreeTypeVariables(const Expr& e); +tvm::Array FreeTypeVariables(const Expr& e); /*! \brief Get free type parameters from type t. * @@ -130,7 +130,7 @@ tvm::Array FreeTypeVariables(const Expr& e); * * \return the set of free type variables. */ -tvm::Array FreeTypeVariables(const Type& t); +tvm::Array FreeTypeVariables(const Type& t); /*! \brief Remove expressions which does not effect the program result. * diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 9a91bd09c70e..2bb9b3070270 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -98,7 +98,7 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); * This can be viewed as template parameter in c++ template function. * * For example, in the following pesudo code, - * the TypeParam of f is TypeParam(kind=kShapeVar, var=n). + * the TypeVar of f is TypeVar(kind=kShapeVar, var=n). * This function can take in a Tensor with shape=(3, 3) and * returns a Tensor with shape=(9,) * @@ -108,13 +108,13 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] * * \endcode - * \sa TypeParamNode The actual container class of TypeParam + * \sa TypeVarNode The actual container class of TypeVar */ -class TypeParam; -/*! \brief TypeParam container node */ -class TypeParamNode : public TypeNode { +class TypeVar; +/*! \brief TypeVar container node */ +class TypeVarNode : public TypeNode { public: - /*! \brief possible kinds of TypeParam */ + /*! \brief possible kinds of TypeVar */ enum Kind : int { /*! \brief template variable in shape expression */ kType = 0, @@ -136,13 +136,13 @@ class TypeParamNode : public TypeNode { v->Visit("span", &span); } - TVM_DLL static TypeParam make(std::string name, Kind kind); + TVM_DLL static TypeVar make(std::string name, Kind kind); - static constexpr const char* _type_key = "relay.TypeParam"; - TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode); + static constexpr const char* _type_key = "relay.TypeVar"; + TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); +RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); /*! * \brief IncompleteType. @@ -150,20 +150,20 @@ RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type); * * If we view the type relations as "computational graph of types", * then IncompleteType represents intermediate values of the graph, - * TypeParam represents the input to the graph. + * TypeVar represents the input to the graph. */ class IncompleteType; /*! \brief IncompleteType container node */ class IncompleteTypeNode : public TypeNode { public: - TypeParamNode::Kind kind; + TypeVarNode::Kind kind; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); } - TVM_DLL static IncompleteType make(TypeParamNode::Kind kind); + TVM_DLL static IncompleteType make(TypeVarNode::Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); @@ -192,7 +192,7 @@ class FuncType; * Relay support polymorphic function type. * This can be roughly viewed as template function in C++. * - * \sa TypeParam, TypeConstraint + * \sa TypeVar, TypeConstraint */ class FuncTypeNode : public TypeNode { public: @@ -203,7 +203,7 @@ class FuncTypeNode : public TypeNode { // The following fields are used in polymorphic(template) functions // For normal functions, the following two fields will be empty. /*! \brief The type parameters of the function */ - tvm::Array type_params; + tvm::Array type_params; /*! * \brief potential constraint the type need to obey * \note this field is reserved for futher purposes. @@ -220,7 +220,7 @@ class FuncTypeNode : public TypeNode { TVM_DLL static FuncType make(tvm::Array arg_types, Type ret_type, - tvm::Array type_params, + tvm::Array type_params, tvm::Array type_constraints); static constexpr const char* _type_key = "relay.FuncType"; diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 4e53b6ba9aab..731a816460ee 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -28,7 +28,7 @@ TupleType = ty.TupleType TensorType = ty.TensorType Kind = ty.Kind -TypeParam = ty.TypeParam +TypeVar = ty.TypeVar TypeConstraint = ty.TypeConstraint FuncType = ty.FuncType TypeRelation = ty.TypeRelation diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index f3c61eec9155..824b0f20e281 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -56,7 +56,7 @@ class Kind(IntEnum): Shape = 3 @register_relay_node -class TypeParam(Type): +class TypeVar(Type): """A type parameter used for generic types in Relay, see tvm/relay/type.h for more details. @@ -66,7 +66,7 @@ class TypeParam(Type): """ def __init__(self, var, kind=Kind.Type): - """Construct a TypeParam. + """Construct a TypeVar. Parameters ---------- @@ -78,10 +78,10 @@ def __init__(self, var, kind=Kind.Type): Returns ------- - type_param: TypeParam + type_param: TypeVar The type parameter. """ - self.__init_handle_by_constructor__(_make.TypeParam, var, kind) + self.__init_handle_by_constructor__(_make.TypeVar, var, kind) @register_relay_node @@ -131,7 +131,7 @@ class FuncType(Type): ret_type: tvm.relay.Type The return type. - type_params: List[tvm.relay.TypeParam] + type_params: List[tvm.relay.TypeVar] The type parameters type_constraints: List[tvm.relay.TypeConstraint] diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 2d9180e4597b..6dfaa0b24a53 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -26,7 +26,7 @@ Environment EnvironmentNode::make(tvm::Map global_funcs) { return Environment(n); } -GlobalVar EnvironmentNode::GetGlobalVar(const std::string &name) { +GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) { auto it = global_var_map_.find(name); CHECK(it != global_var_map_.end()) << "Cannot find global var " << name << " in the Environment"; @@ -42,14 +42,11 @@ void EnvironmentNode::Add(const GlobalVar& var, auto type = checked_func->checked_type(); CHECK(type.as() == nullptr); if (functions.find(var) != functions.end()) { - if (!update) { - throw dmlc::Error("already have definition for XXXX."); - } + CHECK(update) + << "Already have definition for " << var->name_hint; auto old_type = functions[var].as()->checked_type(); - if (!AlphaEqual(type, old_type)) { - throw dmlc::Error( - "Environment#update changes type, not possible in this mode."); - } + CHECK(AlphaEqual(type, old_type)) + << "Environment#update changes type, not possible in this mode."; } this->functions.Set(var, checked_func); // set gloval var map @@ -70,12 +67,10 @@ void EnvironmentNode::Remove(const GlobalVar& var) { } Function EnvironmentNode::Lookup(const GlobalVar& var) { - auto func = functions.find(var); - if (func != functions.end()) { - return (*func).second; - } else { - throw Error(std::string("there is no definition of ") + var->name_hint); - } + auto it = functions.find(var); + CHECK(it != functions.end()) + << "There is no definition of " << var->name_hint; + return (*it).second; } Function EnvironmentNode::Lookup(const std::string &name) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index a1d274e3a78e..2d373b769559 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -104,7 +104,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) Function FunctionNode::make(tvm::Array params, Expr body, Type ret_type, - tvm::Array type_params) { + tvm::Array type_params) { NodePtr n = make_node(); n->params = std::move(params); n->body = std::move(body); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 26d9939aae10..a7367c384cb3 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -66,11 +66,11 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { } Expr ExprMutator::VisitExpr_(const FunctionNode* op) { - tvm::Array ty_params; + tvm::Array ty_params; bool all_ty_params_changed = true; for (auto ty_param : op->type_params) { - TypeParam new_ty_param = Downcast(VisitType(ty_param)); + TypeVar new_ty_param = Downcast(VisitType(ty_param)); ty_params.push_back(new_ty_param); all_ty_params_changed &= new_ty_param.same_as(ty_param); } diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index f45ab3b4c9a7..39347adced92 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -36,30 +36,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { - NodePtr n = make_node(); +TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { + NodePtr n = make_node(); n->var = tvm::Var(name); n->kind = std::move(kind); - return TypeParam(n); + return TypeVar(n); } -TVM_REGISTER_NODE_TYPE(TypeParamNode); +TVM_REGISTER_NODE_TYPE(TypeVarNode); -TVM_REGISTER_API("relay._make.TypeParam") +TVM_REGISTER_API("relay._make.TypeVar") .set_body([](TVMArgs args, TVMRetValue *ret) { int kind = args[1]; *ret = - TypeParamNode::make(args[0], static_cast(kind)); + TypeVarNode::make(args[0], static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const TypeParamNode *node, +.set_dispatch([](const TypeVarNode *node, tvm::IRPrinter *p) { - p->stream << "TypeParamNode(" << node->var->name_hint << ", " + p->stream << "TypeVarNode(" << node->var->name_hint << ", " << node->kind << ")"; }); -IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { +IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) { auto n = make_node(); n->kind = std::move(kind); return IncompleteType(n); @@ -70,7 +70,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_API("relay._make.IncompleteType") .set_body([](TVMArgs args, TVMRetValue* ret) { int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); + *ret = IncompleteTypeNode::make(static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, - tvm::Array type_params, + tvm::Array type_params, tvm::Array type_constraints) { NodePtr n = make_node(); n->arg_types = std::move(arg_types); diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 29d2f87cf04a..ef310d5ed8dc 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -34,7 +34,7 @@ bool SameNDArray(const NDArray& lhs, const NDArray& rhs) { } struct TypeAlphaEq : TypeVisitor { - tvm::Map eq_map; + tvm::Map eq_map; bool equal; TypeAlphaEq() : eq_map(), equal(true) {} @@ -76,10 +76,10 @@ struct TypeAlphaEq : TypeVisitor { } } - void VisitType_(const TypeParamNode* ti1, const Type& t2) final { - if (const TypeParamNode* ti2 = t2.as()) { - auto tid1 = GetRef(ti1); - auto tid2 = GetRef(ti2); + void VisitType_(const TypeVarNode* ti1, const Type& t2) final { + if (const TypeVarNode* ti2 = t2.as()) { + auto tid1 = GetRef(ti1); + auto tid2 = GetRef(ti2); // We handle open terms with this rule assuming variables are identical. // diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 72807985ced4..3f4d81b7e24f 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -20,7 +20,7 @@ namespace tvm { namespace relay { using namespace tvm::runtime; -using Kind = TypeParamNode::Kind; +using Kind = TypeVarNode::Kind; struct KindChecker : TypeVisitor<> { bool valid; @@ -33,7 +33,7 @@ struct KindChecker : TypeVisitor<> { return tv->kind == k; } - if (const TypeParamNode *tp = t.as()) { + if (const TypeVarNode *tp = t.as()) { return tp->kind == k; } diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 43b8bb8bba1d..904ceab36c3d 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -61,7 +61,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr); + return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr); } /*! diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h index 81f93cacaa80..b8eaa85a73d2 100644 --- a/src/relay/pass/type_functor.h +++ b/src/relay/pass/type_functor.h @@ -61,7 +61,7 @@ class TypeFunctor { // Functions that can be overriden by subclass virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -79,7 +79,7 @@ class TypeFunctor { FType vtable; // Set dispatch RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode); - RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode); + RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode); RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3801987c932f..3e233274af2e 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -124,7 +124,7 @@ class TypeInferencer : private ExprFunctor { if (op->type_annotation.defined()) { return op->type_annotation; } else { - return IncompleteTypeNode::make(TypeParamNode::kType); + return IncompleteTypeNode::make(TypeVarNode::kType); } } @@ -157,7 +157,7 @@ class TypeInferencer : private ExprFunctor { EnvFunc::Get("tvm.relay.type_relation.TupleGetItem").node_); } Type tuple_type = GetType(op->tuple); - Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); auto attrs = make_node(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( @@ -205,7 +205,7 @@ class TypeInferencer : private ExprFunctor { for (size_t i = 0; i < op->type_params.size(); ++i) { if (!op->type_params[i].same_as(rel->args[i])) return Type(); } - Type rtype = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here solver_.AddConstraint(TypeRelationNode::make( @@ -215,7 +215,7 @@ class TypeInferencer : private ExprFunctor { // instantiate the function type with fresh FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args) { - tvm::Map subst_map; + tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. @@ -232,7 +232,7 @@ class TypeInferencer : private ExprFunctor { // This is a temporary work around to check recursive functions whose // return type is not yet known. if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeParamNode::Kind::kType); + ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, ret_type, {}, diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc index 0b17fa0bc4f8..bffd779d1af2 100644 --- a/src/relay/pass/type_subst.cc +++ b/src/relay/pass/type_subst.cc @@ -10,13 +10,13 @@ namespace tvm { namespace relay { struct TypeSubstV : TypeMutator { - tvm::Map subst_map; + tvm::Map subst_map; - explicit TypeSubstV(tvm::Map subst_map) + explicit TypeSubstV(tvm::Map subst_map) : subst_map(subst_map) {} - Type VisitType_(const TypeParamNode* op) override { - auto id = GetRef(op); + Type VisitType_(const TypeVarNode* op) override { + auto id = GetRef(op); if (subst_map.find(id) != subst_map.end()) { return this->subst_map[id]; } else { @@ -25,12 +25,12 @@ struct TypeSubstV : TypeMutator { } }; -Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) { +Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) { TypeSubstV ty_sub({ {target, subst} }); return ty_sub.VisitType(type); } -Type TypeSubst(const Type& type, tvm::Map subst_map) { +Type TypeSubst(const Type& type, tvm::Map subst_map) { TypeSubstV ty_sub(subst_map); return ty_sub.VisitType(type); } diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h index aee3209afb7a..808e3536ae30 100644 --- a/src/relay/pass/type_subst.h +++ b/src/relay/pass/type_subst.h @@ -11,8 +11,8 @@ namespace tvm { namespace relay { -Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst); -Type TypeSubst(const Type& type, tvm::Map subst_map); +Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst); +Type TypeSubst(const Type& type, tvm::Map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 6468269686e8..c1b2c3e1a3ad 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -19,7 +19,7 @@ namespace relay { */ template struct TypeVisitor : ::tvm::relay::TypeFunctor { - void VisitType_(const TypeParamNode* op, Args... args) override {} + void VisitType_(const TypeVarNode* op, Args... args) override {} void VisitType_(const FuncTypeNode* op, Args... args) override { for (auto type_param : op->type_params) { @@ -60,16 +60,16 @@ struct TypeMutator : TypeFunctor { return TensorTypeNode::make(op->shape, op->dtype); } - Type VisitType_(const TypeParamNode* op) override { - return GetRef(op); + Type VisitType_(const TypeVarNode* op) override { + return GetRef(op); } Type VisitType_(const FuncTypeNode* op) override { - Array type_params; + Array type_params; for (auto type_param : op->type_params) { auto new_type_param = VisitType(type_param); - if (const TypeParamNode* tin = new_type_param.as()) { - type_params.push_back(GetRef(tin)); + if (const TypeVarNode* tin = new_type_param.as()) { + type_params.push_back(GetRef(tin)); } else { CHECK(false) << new_type_param << std::endl; } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index c845995b2003..8ebac921203f 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -14,14 +14,14 @@ namespace relay { class FreeVar; class FreeTypeVar : private TypeVisitor<> { - std::unordered_set * free_vars; - std::unordered_set * bound_vars; - FreeTypeVar(std::unordered_set * free_vars, - std::unordered_set * bound_vars) : + std::unordered_set * free_vars; + std::unordered_set * bound_vars; + FreeTypeVar(std::unordered_set * free_vars, + std::unordered_set * bound_vars) : free_vars(free_vars), bound_vars(bound_vars) { } - void VisitType_(const TypeParamNode* tp) final { - auto var = GetRef(tp); + void VisitType_(const TypeVarNode* tp) final { + auto var = GetRef(tp); if (bound_vars->count(var) == 0) { free_vars->insert(var); } @@ -75,8 +75,8 @@ class FreeVar : public ExprVisitor { public: std::unordered_set free_vars; std::unordered_set bound_vars; - std::unordered_set free_types; - std::unordered_set bound_types; + std::unordered_set free_types; + std::unordered_set bound_types; void VisitType(const Type& t) final { FreeTypeVar(&free_types, &bound_types)(t); @@ -89,16 +89,16 @@ tvm::Array FreeVariables(const Expr& e) { return tvm::Array(fv.free_vars.begin(), fv.free_vars.end()); } -tvm::Array FreeTypeVariables(const Expr& e) { +tvm::Array FreeTypeVariables(const Expr& e) { FreeVar fv; fv.VisitExpr(e); - return tvm::Array(fv.free_types.begin(), fv.free_types.end()); + return tvm::Array(fv.free_types.begin(), fv.free_types.end()); } -tvm::Array FreeTypeVariables(const Type& t) { +tvm::Array FreeTypeVariables(const Type& t) { FreeVar fv; fv.VisitType(t); - return tvm::Array(fv.free_types.begin(), fv.free_types.end()); + return tvm::Array(fv.free_types.begin(), fv.free_types.end()); } TVM_REGISTER_API("relay._ir_pass.free_vars") diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 20c45f5a16c5..fc9f30c9a61d 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -34,7 +34,7 @@ def test_tensor_type(): def test_type_param(): - tp = relay.TypeParam('name', relay.Kind.Type) + tp = relay.TypeVar('name', relay.Kind.Type) assert tp.kind == relay.Kind.Type # assert tp.span # TODO allow us to set span str(tp) @@ -56,7 +56,7 @@ def test_func_type(): def test_tuple_type(): - tp = relay.TypeParam('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.Kind.Type) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') fields = tvm.convert([tp, tf, tt]) @@ -66,7 +66,7 @@ def test_tuple_type(): def test_type_relation(): - tp = relay.TypeParam('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.Kind.Type) tf = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([])) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') args = tvm.convert([tf, tt, tp]) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 18959687a0fd..983f3e4a13bc 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -28,9 +28,9 @@ def test_incomplete_type_alpha_equal(): def test_type_param_alpha_equal(): - t1 = relay.TypeParam("v1", relay.Kind.Type) - t2 = relay.TypeParam("v2", relay.Kind.Shape) - t3 = relay.TypeParam("v3", relay.Kind.Type) + t1 = relay.TypeVar("v1", relay.Kind.Type) + t2 = relay.TypeVar("v2", relay.Kind.Shape) + t3 = relay.TypeVar("v3", relay.Kind.Type) # only pointer equality and eq_map allow equal params assert t1 == t1 @@ -53,10 +53,10 @@ def test_func_type_alpha_equal(): t1 = relay.TensorType((1, 2), "float32") t2 = relay.TensorType((1, 2, 3), "float32") - tp1 = relay.TypeParam("v1", relay.Kind.Type) - tp2 = relay.TypeParam("v2", relay.Kind.Type) - tp3 = relay.TypeParam("v3", relay.Kind.Shape) - tp4 = relay.TypeParam("v3", relay.Kind.Shape) + tp1 = relay.TypeVar("v1", relay.Kind.Type) + tp2 = relay.TypeVar("v2", relay.Kind.Type) + tp3 = relay.TypeVar("v3", relay.Kind.Shape) + tp4 = relay.TypeVar("v3", relay.Kind.Shape) broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") identity = tvm.get_env_func("tvm.relay.type_relation.Identity") @@ -112,8 +112,8 @@ def test_func_type_alpha_equal(): def test_tuple_type_alpha_equal(): t1 = relay.TensorType((1, 2, 3), "float32") t2 = relay.TensorType((1, 2, 3, 4), "float32") - tp1 = relay.TypeParam("v1", relay.Kind.Type) - tp2 = relay.TypeParam("v2", relay.Kind.Type) + tp1 = relay.TypeVar("v1", relay.Kind.Type) + tp2 = relay.TypeVar("v2", relay.Kind.Type) tup1 = relay.TupleType(tvm.convert([t1, t2, tp1])) tup2 = relay.TupleType(tvm.convert([t1, t2, tp1])) @@ -253,10 +253,10 @@ def test_function_alpha_equal(): v4 = relay.Var("v4", tt2) vret = relay.Constant(tvm.nd.array(np.ones(1))) - tp1 = relay.TypeParam("tp1", relay.Kind.Type) - tp2 = relay.TypeParam("tp2", relay.Kind.Type) - tp3 = relay.TypeParam("tp3", relay.Kind.Shape) - tp4 = relay.TypeParam("tp4", relay.Kind.Shape) + tp1 = relay.TypeVar("tp1", relay.Kind.Type) + tp2 = relay.TypeVar("tp2", relay.Kind.Type) + tp3 = relay.TypeVar("tp3", relay.Kind.Shape) + tp4 = relay.TypeVar("tp4", relay.Kind.Shape) basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)] basic_tps = [tp1, tp2] diff --git a/tests/python/relay/test_pass_check_kind.py b/tests/python/relay/test_pass_check_kind.py index 314c8c8b7992..5ead501157c5 100644 --- a/tests/python/relay/test_pass_check_kind.py +++ b/tests/python/relay/test_pass_check_kind.py @@ -4,7 +4,7 @@ def test_tuple_kind(): # only contain type kinds - tp = relay.TypeParam('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.Kind.Type) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) fields = tvm.convert([tp, tf, tt]) @@ -15,8 +15,8 @@ def test_tuple_kind(): def test_func_kind(): # only contain type kinds - tp1 = relay.TypeParam('tp1', relay.Kind.Type) - tp2 = relay.TypeParam('tp2', relay.Kind.Type) + tp1 = relay.TypeVar('tp1', relay.Kind.Type) + tp2 = relay.TypeVar('tp2', relay.Kind.Type) shape = tvm.convert([1, 2, 3]) dtype = 'float32' @@ -35,7 +35,7 @@ def test_func_kind(): def test_relation_kind(): # only have type kinds for arguments - tp = relay.TypeParam('tp', relay.Kind.Type) + tp = relay.TypeVar('tp', relay.Kind.Type) tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tf = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([])) args = tvm.convert([tf, tt, tp]) @@ -45,9 +45,9 @@ def test_relation_kind(): def test_invalid_tuple_kind(): - tp1 = relay.TypeParam('tp1', relay.Kind.Shape) - tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) - tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.Kind.Shape) + tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) + tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) fields = tvm.convert([tp1, tp2, tp3]) tup_ty = relay.TupleType(fields) @@ -55,9 +55,9 @@ def test_invalid_tuple_kind(): def test_invalid_func_kind(): - tp1 = relay.TypeParam('tp1', relay.Kind.Shape) - tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) - tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.Kind.Shape) + tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) + tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) type_params = tvm.convert([tp1, tp2, tp3]) type_constraints = tvm.convert([]) @@ -69,9 +69,9 @@ def test_invalid_func_kind(): def test_invalid_relation_kind(): - tp1 = relay.TypeParam('tp1', relay.Kind.Shape) - tp2 = relay.TypeParam('tp2', relay.Kind.BaseType) - tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.Kind.Shape) + tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) + tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) args = tvm.convert([tp1, tp2, tp3]) tr = relay.TypeRelation(None, args, 2, None) @@ -79,19 +79,19 @@ def test_invalid_relation_kind(): def test_func_with_invalid_ret_type(): - tp1 = relay.TypeParam('tp1', relay.Kind.Type) - tp2 = relay.TypeParam('tp2', relay.Kind.Shape) + tp1 = relay.TypeVar('tp1', relay.Kind.Type) + tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) def test_func_with_invalid_arg_types(): - tp1 = relay.TypeParam('tp1', relay.Kind.Shape) - tp2 = relay.TypeParam('tp2', relay.Kind.Type) + tp1 = relay.TypeVar('tp1', relay.Kind.Shape) + tp2 = relay.TypeVar('tp2', relay.Kind.Type) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) def test_func_with_invalid_tuple(): - tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + tp1 = relay.TypeVar('tp1', relay.Kind.Shape) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) @@ -100,9 +100,9 @@ def test_func_with_invalid_tuple(): def test_func_with_invalid_relation(): - tp1 = relay.TypeParam('tp1', relay.Kind.Type) - tp2 = relay.TypeParam('tp2', relay.Kind.Shape) - tp3 = relay.TypeParam('tp3', relay.Kind.ShapeVar) + tp1 = relay.TypeVar('tp1', relay.Kind.Type) + tp2 = relay.TypeVar('tp2', relay.Kind.Shape) + tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None) @@ -113,7 +113,7 @@ def test_func_with_invalid_relation(): def test_tuple_with_invalid_func(): tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') - tp1 = relay.TypeParam('tp1', relay.Kind.Shape) + tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py index 524196661753..151dbe1412bc 100644 --- a/tests/python/relay/test_pass_free_vars.py +++ b/tests/python/relay/test_pass_free_vars.py @@ -28,7 +28,7 @@ def test_tuple(): def test_free_type_vars(): - tp = relay.TypeParam("") + tp = relay.TypeVar("") ty = relay.TupleType([tp, relay.TensorType([], "int32")]) x = relay.Var("x", ty) y = relay.Var("y")