Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] IR builder stablize refactor, clean pass #1934

Merged
merged 2 commits into from
Oct 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
double alpha;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") {
TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25)
.describe("Slope coefficient for the negative half axis.");
}
Expand Down
28 changes: 18 additions & 10 deletions include/tvm/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,52 +47,60 @@ class EnvironmentNode : public RelayNode {

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_map_", &global_map_);
v->Visit("global_var_map_", &global_var_map_);
}

TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);

/*! \brief Add a function to the global environment.
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);

/*! \brief Update a function in the global environment.
/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);

/*! \brief Remove a function from the global environment.
/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);

/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);

/*! \brief Lookup a global function by its variable.
/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);

/*! \brief Lookup a global function by its string name
/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);

/*! \brief Combine with another Environment.
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
void Merge(const Environment& other);
void Update(const Environment& other);

static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
Expand All @@ -101,7 +109,7 @@ class EnvironmentNode : public RelayNode {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
tvm::Map<std::string, GlobalVar> global_var_map_;
};

struct Environment : public NodeRef {
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class FunctionNode : public ExprNode {
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeParam> type_params;
tvm::Array<TypeVar> type_params;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
Expand All @@ -219,7 +219,7 @@ class FunctionNode : public ExprNode {
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeParam> ty_params);
tvm::Array<TypeVar> ty_params);

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
Expand Down Expand Up @@ -375,13 +375,14 @@ class TupleGetItemNode : public ExprNode {
int index;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.GetItem";
static constexpr const char * _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,22 +371,22 @@ inline OpRegistry& OpRegistry::add_type_rel(
env_type_rel_func = env_func;
}

Array<TypeParam> type_params;
Array<TypeVar> type_params;
Array<Type> arg_types;

// Add inputs.
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}

Array<Type> ty_call_args = arg_types;

// Add output type.
auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
Expand Down
27 changes: 18 additions & 9 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@
namespace tvm {
namespace relay {

/*! \brief Infer the type of an expression with the provided environment.
/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \param env The environment used for global settings and referencing
* global functions.
*
* \param e The expression to type check.
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& var, const Function& f);
Expr InferType(const Expr& expr, const Environment& env);
/*!
* \brief Infer the type of a function as if it is mapped to var in the env.
*
* \param f the function.
* \param env The environment used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);

/*!
* \brief Check that types are well kinded by applying "kinding rules".
Expand Down Expand Up @@ -111,7 +120,7 @@ tvm::Array<Var> FreeVariables(const Expr& e);
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
tvm::Array<TypeVar> FreeTypeVariables(const Expr& e);

/*! \brief Get free type parameters from type t.
*
Expand All @@ -121,7 +130,7 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
*
* \return the set of free type variables.
*/
tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
tvm::Array<TypeVar> FreeTypeVariables(const Type& t);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down
32 changes: 16 additions & 16 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
* This can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
Expand All @@ -108,13 +108,13 @@ RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeParamNode The actual container class of TypeParam
* \sa TypeVarNode The actual container class of TypeVar
*/
class TypeParam;
/*! \brief TypeParam container node */
class TypeParamNode : public TypeNode {
class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*! \brief possible kinds of TypeParam */
/*! \brief possible kinds of TypeVar */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
Expand All @@ -136,34 +136,34 @@ class TypeParamNode : public TypeNode {
v->Visit("span", &span);
}

TVM_DLL static TypeParam make(std::string name, Kind kind);
TVM_DLL static TypeVar make(std::string name, Kind kind);

static constexpr const char* _type_key = "relay.TypeParam";
TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
static constexpr const char* _type_key = "relay.TypeVar";
TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);

/*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
* If we view the type relations as "computational graph of types",
* then IncompleteType represents intermediate values of the graph,
* TypeParam represents the input to the graph.
* TypeVar represents the input to the graph.
*/
class IncompleteType;

/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeParamNode::Kind kind;
TypeVarNode::Kind kind;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind);
}

TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);

static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
Expand Down Expand Up @@ -192,7 +192,7 @@ class FuncType;
* Relay support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeParam, TypeConstraint
* \sa TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
Expand All @@ -203,7 +203,7 @@ class FuncTypeNode : public TypeNode {
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
tvm::Array<TypeParam> type_params;
tvm::Array<TypeVar> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
Expand All @@ -220,7 +220,7 @@ class FuncTypeNode : public TypeNode {

TVM_DLL static FuncType make(tvm::Array<Type> arg_types,
Type ret_type,
tvm::Array<TypeParam> type_params,
tvm::Array<TypeVar> type_params,
tvm::Array<TypeConstraint> type_constraints);

static constexpr const char* _type_key = "relay.FuncType";
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from . import expr
from . import env
from . import ir_pass
from . import ir_builder

# Root operators
from .op import Op
Expand All @@ -16,6 +15,8 @@
from . import vision
from . import image

from .scope_builder import ScopeBuilder

# Span
Span = base.Span

Expand All @@ -27,11 +28,12 @@
TupleType = ty.TupleType
TensorType = ty.TensorType
Kind = ty.Kind
TypeParam = ty.TypeParam
TypeVar = ty.TypeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type

# Expr
Constant = expr.Constant
Expand Down
Loading