Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] IR builder stablize refactor, clean pass #1934

Merged
merged 2 commits into from
Oct 20, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
double alpha;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
.describe("Slope coefficient for the negative half axis.");
}
Expand Down
28 changes: 18 additions & 10 deletions include/tvm/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,52 +47,60 @@ class EnvironmentNode : public RelayNode {

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_map_", &global_map_);
v->Visit("global_var_map_", &global_var_map_);
}

TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);

/*! \brief Add a function to the global environment.
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);

/*! \brief Update a function in the global environment.
/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);

/*! \brief Remove a function from the global environment.
/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);

/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);

/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);

/*! \brief Lookup a global function by its string name
/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);

/*! \brief Combine with another Environment.
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
void Merge(const Environment& other);
void Update(const Environment& other);

static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
Expand All @@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
tvm::Map<std::string, GlobalVar> global_var_map_;
};

struct Environment : public NodeRef {
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode {
int index;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.GetItem";
static constexpr const char * _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

Expand Down
23 changes: 16 additions & 7 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@
namespace tvm {
namespace relay {

/*! \brief Infer the type of an expression with the provided environment.
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \param env The environment used for global settings and referencing
* global functions.
*
* \param e The expression to type check.
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
Expr InferType(const Expr& expr, const Environment& env);
/*!
* \brief Infer the type of a function as if it is mapped to var in the env.
*
* \param f the function.
* \param env The environment used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);

/*!
* \brief Check that types are well kinded by applying "kinding rules".
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from . import expr
from . import env
from . import ir_pass
from . import ir_builder

# Root operators
from .op import Op
Expand All @@ -16,6 +15,8 @@
from . import vision
from . import image

from .scope_builder import ScopeBuilder

# Span
Span = base.Span

Expand All @@ -32,6 +33,7 @@
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type

# Expr
Constant = expr.Constant
Expand Down
106 changes: 60 additions & 46 deletions python/tvm/relay/env.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, RelayNode
from .._ffi import base as _base
from . import _make
from . import _env
from . import expr as _expr


@register_relay_node
class Environment(RelayNode):
"""The global Relay environment containing functions,
options and more.
"""

def __init__(self, funcs=None):
"""Construct an environment.

Parameters
------
funcs : optional, dict
Map of global var to Function
"""The global Relay environment containing collection of functions.

Returns
------
env: A new environment containing :py:class:`~relay.env.Environment`.
"""
funcs = funcs if funcs else {}
self.__init_handle_by_constructor__(_make.Environment, funcs)
Each global function is identified by an unique tvm.relay.GlobalVar.
tvm.relay.GlobalVar and Environment is necessary in order to enable
recursions in function to avoid cyclic reference in the function.x

def add(self, var, func):
Parameters
----------
functions : dict, optional.
Map of global var to Function
"""
def __init__(self, functions=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
mapped_funcs = {}
for k, v in functions.items():
if isinstance(k, _base.string_types):
k = _expr.GlobalVar(k)
if not isinstance(k, _expr.GlobalVar):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
mapped_funcs[k] = v
functions = mapped_funcs
self.__init_handle_by_constructor__(_make.Environment, functions)

def __setitem__(self, var, func):
"""Add a function to the environment.

Parameters
Expand All @@ -36,50 +45,55 @@ def add(self, var, func):
func: Function
The function.
"""
if isinstance(var, str):
var = _env.Environment_GetGlobalVar(self, var)

if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)
_env.Environment_Add(self, var, func)

def merge(self, other):
"""Merge two environments.
def __getitem__(self, var):
"""Lookup a global function by name or by variable.

Parameters
----------
other: Environment
The environment to merge into the current Environment.
var: str or GlobalVar
The name or global variable.

Returns
-------
func: Function
The function referenced by :code:`var`.
"""
return _env.Environment_Merge(self, other)
if isinstance(var, _base.string_types):
return _env.Environment_Lookup_str(self, var)
else:
return _env.Environment_Lookup(self, var)

def global_var(self, name):
"""Get a global variable by name.
def update(self, other):
"""Insert functions in another Environment to current one.

Parameters
----------
name: str
The name of the global variable.

Returns
-------
global_var: GlobalVar
The global variable mapped to :code:`name`.
other: Environment
The environment to merge into the current Environment.
"""
return _env.Environment_GetGlobalVar(self, name)
if isinstance(other, dict):
other = Environment(other)
return _env.Environment_Update(self, other)

def __getitem__(self, var):
"""Lookup a global function by name or by variable.
def get_global_var(self, name):
"""Get a global variable in the function by name.

Parameters
----------
var: str or GlobalVar
The name or global variable.
name: str
The name of the global variable.

Returns
-------
func: Function
The function referenced by :code:`var`.
global_var: GlobalVar
The global variable mapped to :code:`name`.

Raises
------
tvm.TVMError if we cannot find corresponding global var.
"""
if isinstance(var, str):
return _env.Environment_Lookup_str(self, var)
else:
return _env.Environment_Lookup(self, var)
return _env.Environment_GetGlobalVar(self, name)
Loading