Skip to content

Commit

Permalink
[Relay] Unifier hotfix (apache#2437)
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and AWS Neo committed Feb 20, 2019
1 parent 41328a8 commit 2b8a5c0
Show file tree
Hide file tree
Showing 12 changed files with 1,072 additions and 254 deletions.
68 changes: 68 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/
bool WellFormed(const Expr& expr);

/*! \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> BoundVars(const Expr& expr);

/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
Expand All @@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr);
*/
tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> AllVars(const Expr& expr);

/*! \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function
Expand All @@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr);
*/
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);

/*! \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param t the type.
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> FreeTypeVars(const Type& t);

/*! \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);

/*! \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
*
* \param t the type
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> BoundTypeVars(const Type& t);

/*! \brief Get all type variables in expression expr.
*
* \param expr the expression.
*
* \return List of type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> AllTypeVars(const Expr& expr);

/*! \brief Get all type variables in type t.
*
* \param t the type.
*
* \return List of type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> AllTypeVars(const Type& t);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced, and branches that will
Expand Down
68 changes: 66 additions & 2 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,38 @@ def free_vars(expr):
return _ir_pass.free_vars(expr)


def bound_vars(expr):
"""Get bound vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of bound variables in post-DFS order.
"""
return _ir_pass.bound_vars(expr)


def all_vars(expr):
"""Get all vars from expression expr in post-DFS order.
Parameters
----------
expr: tvm.relay.Expr
The input expression
Returns
-------
free : List[tvm.relay.Var]
The list of all variables in post-DFS order.
"""
return _ir_pass.all_vars(expr)


def free_type_vars(expr):
"""Get free type variables from expression/type e
Expand All @@ -168,12 +200,44 @@ def free_type_vars(expr):
Returns
-------
free : List[tvm.relay.TypeParam]
The list of free type variables
free : List[tvm.relay.TypeVar]
The list of free type variables in post-DFS order
"""
return _ir_pass.free_type_vars(expr)


def bound_type_vars(expr):
"""Get bound type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
return _ir_pass.bound_type_vars(expr)


def all_type_vars(expr):
"""Get all type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
Returns
-------
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
return _ir_pass.all_type_vars(expr)


def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase.
Expand Down
19 changes: 15 additions & 4 deletions src/relay/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
});
return Pair(res.foward, grad);
});

// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type ret_type = Type();
std::vector<Type> vt;
bool missing = !f->ret_type.defined();
for (const auto& p : f->params) {
if (missing || !p->type_annotation.defined()) {
missing = true;
break;
}
vt.push_back(p->type_annotation);
}
return FunctionNode::make(f->params,
body,
TupleTypeNode::make({f->ret_type, TupleTypeNode::make({})}),
{});

if (!missing) {
ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}

return FunctionNode::make(f->params, body, ret_type, {});
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
Expand Down
Loading

0 comments on commit 2b8a5c0

Please sign in to comment.