Skip to content

Commit

Permalink
[RELAY] IR builder stablize refactor, clean pass (apache#1934)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and AWS Neo committed Feb 20, 2019
1 parent e9d6157 commit 06f14a7
Show file tree
Hide file tree
Showing 52 changed files with 1,212 additions and 1,836 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
double alpha;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
.describe("Slope coefficient for the negative half axis.");
}
Expand Down
28 changes: 18 additions & 10 deletions include/tvm/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,52 +47,60 @@ class EnvironmentNode : public RelayNode {

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_map_", &global_map_);
v->Visit("global_var_map_", &global_var_map_);
}

TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);

/*! \brief Add a function to the global environment.
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);

/*! \brief Update a function in the global environment.
/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);

/*! \brief Remove a function from the global environment.
/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);

/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);

/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);

/*! \brief Lookup a global function by its string name
/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);

/*! \brief Combine with another Environment.
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
void Merge(const Environment& other);
void Update(const Environment& other);

static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
Expand All @@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
tvm::Map<std::string, GlobalVar> global_var_map_;
};

struct Environment : public NodeRef {
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class FunctionNode : public ExprNode {
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeParam> type_params;
tvm::Array<TypeVar> type_params;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
Expand All @@ -219,7 +219,7 @@ class FunctionNode : public ExprNode {
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeParam> ty_params);
tvm::Array<TypeVar> ty_params);

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
Expand Down Expand Up @@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode {
int index;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.GetItem";
static constexpr const char * _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,22 +371,22 @@ inline OpRegistry& OpRegistry::add_type_rel(
env_type_rel_func = env_func;
}

Array<TypeParam> type_params;
Array<TypeVar> type_params;
Array<Type> arg_types;

// Add inputs.
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}

Array<Type> ty_call_args = arg_types;

// Add output type.
auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
Expand Down
27 changes: 18 additions & 9 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@
namespace tvm {
namespace relay {

/*! \brief Infer the type of an expression with the provided environment.
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \param env The environment used for global settings and referencing
* global functions.
*
* \param e The expression to type check.
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
Expr InferType(const Expr& expr, const Environment& env);
/*!
* \brief Infer the type of a function as if it is mapped to var in the env.
*
* \param f the function.
* \param env The environment used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);

/*!
* \brief Check that types are well kinded by applying "kinding rules".
Expand Down Expand Up @@ -111,7 +120,7 @@ tvm::Array<Var> FreeVariables(const Expr& e);
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);

/*! \brief Get free type parameters from type t.
*
Expand All @@ -121,7 +130,7 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
tvm::Array<TypeVar> FreeTypeVariables(const Type& t);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down
32 changes: 16 additions & 16 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
* This can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
Expand All @@ -108,13 +108,13 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeParamNode The actual container class of TypeParam
* \sa TypeVarNode The actual container class of TypeVar
*/
class TypeParam;
/*! \brief TypeParam container node */
class TypeParamNode : public TypeNode {
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*! \brief possible kinds of TypeParam */
/*! \brief possible kinds of TypeVar */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
Expand All @@ -136,34 +136,34 @@ class TypeParamNode : public TypeNode {
v->Visit("span", &span);
}

TVM_DLL static TypeParam make(std::string name, Kind kind);
TVM_DLL static TypeVar make(std::string name, Kind kind);

static constexpr const char* _type_key = "relay.TypeParam";
TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);

/*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeParam represents the input to the graph.
* TypeVar represents the input to the graph.
*/
class IncompleteType;

/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
TypeVarNode::Kind kind;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind);
}

TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);

static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
Expand Down Expand Up @@ -192,7 +192,7 @@ class FuncType;
* Relay support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeParam, TypeConstraint
* \sa TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
Expand All @@ -203,7 +203,7 @@ class FuncTypeNode : public TypeNode {
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
tvm::Array<TypeParam> type_params;
tvm::Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
Expand All @@ -220,7 +220,7 @@ class FuncTypeNode : public TypeNode {

TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints);

static constexpr const char* _type_key = "relay.FuncType";
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from . import expr
from . import env
from . import ir_pass
from . import ir_builder

# Root operators
from .op import Op
Expand All @@ -16,6 +15,8 @@
from . import vision
from . import image

from .scope_builder import ScopeBuilder

# Span
Span = base.Span

Expand All @@ -27,11 +28,12 @@
TupleType = ty.TupleType
TensorType = ty.TensorType
Kind = ty.Kind
TypeParam = ty.TypeParam
TypeVar = ty.TypeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type

# Expr
Constant = expr.Constant
Expand Down
Loading

0 comments on commit 06f14a7

Please sign in to comment.