From 76188a43d062ecf87721f6672dad0d9f714efc6d Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 17 Jan 2019 04:32:38 +0530 Subject: [PATCH 01/16] [NNVM][TENSORFLOW] bugfix. (#2444) --- nnvm/python/nnvm/frontend/tensorflow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index c0848bb1092c4..9a302da72ae68 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1193,6 +1193,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList( \ tensor_value.tensor_shape)] + elif shape and node.name in shape: + # Give priority to user argument. + self._output_shapes[node.name] = [shape[node.name]] elif '_output_shapes' in attr: self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList(tshape) \ From 6783d373760ad114209e2fcce167174907a97ae7 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 15:07:02 -0800 Subject: [PATCH 02/16] [Relay] Unifier hotfix (#2437) --- include/tvm/relay/pass.h | 68 +++++ python/tvm/relay/ir_pass.py | 68 ++++- src/relay/pass/gradient.cc | 19 +- src/relay/pass/type_infer.cc | 137 +++++---- src/relay/pass/type_solver.cc | 343 +++++++++++++++++++--- src/relay/pass/type_solver.h | 31 +- src/relay/pass/util.cc | 214 +++++++++++--- tests/cpp/relay_pass_type_infer_test.cc | 16 +- tests/python/relay/test_pass_free_vars.py | 41 --- tests/python/relay/test_pass_vars.py | 144 +++++++++ tests/python/relay/test_type_infer.py | 81 +++-- tests/python/relay/test_type_solver.py | 164 +++++++++++ 12 files changed, 1072 insertions(+), 254 deletions(-) delete mode 100644 tests/python/relay/test_pass_free_vars.py create mode 100644 tests/python/relay/test_pass_vars.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1897809f48b80..566d69cc6b0b8 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); */ bool WellFormed(const Expr& expr); +/*! \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +tvm::Array BoundVars(const Expr& expr); + /*! \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a @@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr); */ tvm::Array FreeVars(const Expr& expr); +/*! \brief Get all variables from expression expr. + * + * \param expr the expression. + * + * \return List of all vars, in the PostDFS order in the expression. + */ +tvm::Array AllVars(const Expr& expr); + /*! \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function @@ -130,6 +149,55 @@ tvm::Array FreeVars(const Expr& expr); */ tvm::Array FreeTypeVars(const Expr& expr); +/*! \brief Get free TypeVars from type t. + * + * Free type parameters are type parameters that are not bound by a function + * type in the context. + * + * \param t the type. + * + * \return List of free type vars, in the PostDFS order visited by type. + */ +tvm::Array FreeTypeVars(const Type& t); + +/*! \brief Get all bound type variables from expression expr. + * + * Bound variables are all type variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound type vars, in the PostDFS order in the expression. + */ +tvm::Array BoundTypeVars(const Expr& expr); + +/*! \brief Get all bound type variables from type t. + * + * Bound variables are all type variables that are declared in the type. + * They only have meaning inside that type, and can only be used in it. + * + * \param t the type + * + * \return List of bound type vars, in the PostDFS order visited by type. + */ +tvm::Array BoundTypeVars(const Type& t); + +/*! \brief Get all type variables in expression expr. + * + * \param expr the expression. + * + * \return List of type vars, in the PostDFS order in the expression. + */ +tvm::Array AllTypeVars(const Expr& expr); + +/*! \brief Get all type variables in type t. + * + * \param t the type. + * + * \return List of type vars, in the PostDFS order visited by type. + */ +tvm::Array AllTypeVars(const Type& t); + /*! \brief Remove expressions which does not effect the program result. * * It will remove let bindings which are not referenced, and branches that will diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 1bec7ccd72d58..d5d5e9261fc79 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -158,6 +158,38 @@ def free_vars(expr): return _ir_pass.free_vars(expr) +def bound_vars(expr): + """Get bound vars from expression expr in post-DFS order. + + Parameters + ---------- + expr: tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of bound variables in post-DFS order. + """ + return _ir_pass.bound_vars(expr) + + +def all_vars(expr): + """Get all vars from expression expr in post-DFS order. + + Parameters + ---------- + expr: tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of all variables in post-DFS order. + """ + return _ir_pass.all_vars(expr) + + def free_type_vars(expr): """Get free type variables from expression/type e @@ -168,12 +200,44 @@ def free_type_vars(expr): Returns ------- - free : List[tvm.relay.TypeParam] - The list of free type variables + free : List[tvm.relay.TypeVar] + The list of free type variables in post-DFS order """ return _ir_pass.free_type_vars(expr) +def bound_type_vars(expr): + """Get bound type variables from expression/type e + + Parameters + ---------- + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of bound type variables in post-DFS order + """ + return _ir_pass.bound_type_vars(expr) + + +def all_type_vars(expr): + """Get all type variables from expression/type e + + Parameters + ---------- + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of all type variables in post-DFS order + """ + return _ir_pass.all_type_vars(expr) + + def simplify_inference(expr): """ Simplify the data-flow graph for inference phase. diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 601a09b35d1af..251d7153e4e6d 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { }); return Pair(res.foward, grad); }); + + // if type annotations are provided, we will construct a ret type; + // otherwise, leave it to be inferred + Type ret_type = Type(); std::vector vt; + bool missing = !f->ret_type.defined(); for (const auto& p : f->params) { + if (missing || !p->type_annotation.defined()) { + missing = true; + break; + } vt.push_back(p->type_annotation); } - return FunctionNode::make(f->params, - body, - TupleTypeNode::make({f->ret_type, TupleTypeNode::make({})}), - {}); + + if (!missing) { + ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); + } + + return FunctionNode::make(f->params, body, ret_type, {}); } TVM_REGISTER_API("relay._ir_pass.first_order_gradient") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ee1b5ab101482..af4cc6607a44a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array& types, return true; } -bool MakeTupleRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(static_cast(num_inputs + 1), types.size()); - for (int i = 0; i < num_inputs; ++i) { - if (types[i].as()) return false; - } - Array fields; - for (int i = 0; i < num_inputs; ++i) { - fields.push_back(types[i]); - } - reporter->Assign(types[num_inputs], TupleTypeNode::make(fields)); - return true; -} - TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") .set_body_typed&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); -TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple") -.set_body_typed&, int, const Attrs&, const TypeReporter&)>( - MakeTupleRel); - struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) : checked_type(checked_type), type_args(type_args) {} @@ -120,6 +100,10 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; + // used to ensure we don't have free type vars hanging around + // (a temporary measure until we have proper generalization implemented) + Map instantiation_map_; + // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -140,6 +124,32 @@ class TypeInferencer : private ExprFunctor { return Type(); } } + + // Substitutes every type var in t with a corresponding incomplete type. + // This is a temporary measure to ensure type vars behave until + // generalization is properly implemented. + Type Instantiate(const Type &t) { + if (!t.defined()) { + return t; + } + auto* ft = t.as(); + if (ft == nullptr) { + return Bind(t, instantiation_map_); + } + + for (auto type_param : ft->type_params) { + instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + + Type ret_type = ft->ret_type; + if (!ret_type.defined()) { + ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + } + + auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); + return Bind(strip_tvs, instantiation_map_); + } + // Lazily get type for expr // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { @@ -147,7 +157,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = this->VisitExpr(expr); + Type ret = Instantiate(this->VisitExpr(expr)); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -175,19 +185,11 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const TupleNode* op) final { - if (!make_tuple_rel_.defined()) { - make_tuple_rel_ = TypeRelationFn( - EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_); - } Array types; for (Expr field : op->fields) { types.push_back(GetType(field)); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - types.push_back(rtype); - solver_.AddConstraint(TypeRelationNode::make( - make_tuple_rel_, types, op->fields.size(), Attrs())); - return rtype; + return TupleTypeNode::make(types); } Type VisitExpr_(const TupleGetItemNode* op) final { @@ -209,11 +211,17 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const LetNode* op) final { + // if the definition is a function literal, permit recursion + bool is_functional_literal = op->value.as() != nullptr; + if (is_functional_literal) { + type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + } + Type vtype = GetType(op->value); if (op->var->type_annotation.defined()) { vtype = Unify(vtype, op->var->type_annotation, op->span); } - CHECK(!type_map_.count(op->var)); + CHECK(is_functional_literal || !type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program type_map_[op->var].checked_type = vtype; return GetType(op->body); @@ -252,16 +260,14 @@ class TypeInferencer : private ExprFunctor { return rtype; } - // instantiate the function type with fresh - FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args) { + // substitute the type args in the function type + FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array& ty_args) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (auto ty_param : fn_ty->type_params) { - IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); - subst_map.Set(ty_param, fresh); - ty_args->push_back(fresh); + for (size_t i = 0; i < fn_ty->type_params.size(); i++) { + subst_map.Set(fn_ty->type_params[i], ty_args[i]); } Type ret_type = fn_ty->ret_type; @@ -296,13 +302,32 @@ class TypeInferencer : private ExprFunctor { Type GeneralCall(const CallNode* call, Array arg_types) { Type ftype = GetType(call->op); auto* fn_ty_node = ftype.as(); + auto* inc_ty_node = ftype.as(); + + CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) + << "only expressions with function types can be called, found " + << ftype << " at " << call->span; + + // incomplete type => it must be a function taking the arg types + // with an unknown return type + if (inc_ty_node != nullptr) { + Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); + Type unified = this->Unify(ftype, func_type, call->span); + fn_ty_node = unified.as(); + } - CHECK(fn_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; - - Array type_args; - FuncType fn_ty = Instantiate(fn_ty_node, &type_args); + Array type_args = call->type_args; + if (type_args.size() == 0) { + for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { + type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + } + CHECK(type_args.size() == fn_ty_node->type_params.size()) + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() + << "but got " << type_args.size(); + FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); AddTypeArgs(GetRef(call), type_args); @@ -353,26 +378,17 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const FunctionNode* f) final { + solver_.Solve(); + Array arg_types; for (auto param : f->params) { - GetType(param); + arg_types.push_back(GetType(param)); } Type rtype = GetType(f->body); - // Run solver using the currently known information - solver_.Solve(); - // Trying to resolve - Array arg_types; - for (size_t i = 0; i < f->params.size(); ++i) { - Type atype = solver_.Resolve(GetType(f->params[i])); - CHECK(atype.as() == nullptr) - << "Cannot resolve type of " << i - << "-th parameter of function at" << f->span; - arg_types.push_back(atype); + if (f->ret_type.defined()) { + rtype = this->Unify(f->ret_type, rtype, f->span); } - rtype = solver_.Resolve(rtype); - CHECK(rtype.as() == nullptr) - << "Cannot resolve return type of function at" << f->span; - // do not support constraint lifting for now. - return FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + return solver_.Resolve(ret); } }; @@ -380,7 +396,7 @@ class TypeInferencer::Resolver : public ExprMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { + : tmap_(tmap), solver_(solver) { } Expr VisitExpr_(const VarNode* op) final { @@ -525,6 +541,7 @@ Expr TypeInferencer::Infer(Expr expr) { GetType(expr); // Step 1: Solve the constraints. solver_.Solve(); + // Step 2: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); CHECK(WellFormed(resolved_expr)); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index e1efcbbdd0b91..caea3755b8f9e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -5,6 +5,7 @@ */ #include #include "type_solver.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { @@ -38,9 +39,298 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; +class TypeSolver::OccursChecker : public TypeVisitor { + public: + explicit OccursChecker(TypeSolver* solver, TypeNode* var) + : solver_(solver), var_(var), found_(false) {} + + bool Check(const Type& t) { + VisitType(t); + return found_; + } + + void VisitType_(const IncompleteTypeNode* op) override { + IncompleteType t = GetRef(op); + TypeNode* node = solver_->GetTypeNode(t); + found_ = found_ || (var_->FindRoot() == node->FindRoot()); + } + + private: + TypeSolver* solver_; + TypeNode* var_; + bool found_; +}; + +class TypeSolver::Unifier : public TypeFunctor { + public: + explicit Unifier(TypeSolver* solver) : solver_(solver) {} + + Type Unify(const Type& src, const Type& dst) { + // Known limitation + // - handle shape pattern matching + TypeNode* lhs = solver_->GetTypeNode(dst); + TypeNode* rhs = solver_->GetTypeNode(src); + + // do occur check so we don't create self-referencing structure + if (lhs->FindRoot() == rhs->FindRoot()) { + return lhs->resolved_type; + } + if (lhs->resolved_type.as()) { + CHECK(!CheckOccurs(lhs, rhs->resolved_type)) + << "Incomplete type " << lhs->resolved_type << " occurs in " + << rhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(lhs, rhs); + return rhs->resolved_type; + } else if (rhs->resolved_type.as()) { + CHECK(!CheckOccurs(rhs, lhs->resolved_type)) + << "Incomplete type " << rhs->resolved_type << " occurs in " + << lhs->resolved_type << ", cannot unify"; + solver_->MergeFromTo(rhs, lhs); + return lhs->resolved_type; + } else { + Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); + CHECK(resolved.defined()) + << "Unable to unify parent types: " + << lhs->resolved_type << " and " << rhs->resolved_type; + TypeNode* top = solver_->GetTypeNode(resolved); + solver_->MergeFromTo(lhs, top); + solver_->MergeFromTo(rhs, top); + return resolved; + } + } + + // Checks whether lhs (taken to be a type var) occurs in t, meaning + // there is a recursive equality constraint, which should be rejected. + // N.b.: A tautology like ?a = ?a is okay and should be checked for + // *before* calling this method + bool CheckOccurs(TypeNode* lhs, const Type& t) { + OccursChecker rc(solver_, lhs); + return rc.Check(t); + } + + // default: unify only if alpha-equal + Type VisitTypeDefault_(const Node* op, const Type& tn) override { + NodeRef nr = GetRef(op); + Type t1 = GetRef(nr.as_derived()); + if (!AlphaEqual(t1, tn)) { + return Type(nullptr); + } + return t1; + } + + Type VisitType_(const TupleTypeNode* op, const Type& tn) override { + const auto* ttn = tn.as(); + if (!ttn || op->fields.size() != ttn->fields.size()) { + return Type(nullptr); + } + + TupleType tt1 = GetRef(op); + TupleType tt2 = GetRef(ttn); + + std::vector new_fields; + for (size_t i = 0; i < tt1->fields.size(); i++) { + Type field = Unify(tt1->fields[i], tt2->fields[i]); + new_fields.push_back(field); + } + return TupleTypeNode::make(new_fields); + } + + Type VisitType_(const FuncTypeNode* op, const Type& tn) override { + const auto* ftn = tn.as(); + if (!ftn + || op->arg_types.size() != ftn->arg_types.size() + || op->type_params.size() != ftn->type_params.size() + || op->type_constraints.size() != ftn->type_constraints.size()) { + return Type(nullptr); + } + + // remap type vars so they match + Map subst_map; + for (size_t i = 0; i < op->type_params.size(); i++) { + subst_map.Set(ftn->type_params[i], op->type_params[i]); + } + + auto ft1 = GetRef(op); + auto ft2 = Downcast(Bind(GetRef(ftn), subst_map)); + + Type ret_type = Unify(ft1->ret_type, ft2->ret_type); + + std::vector arg_types; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); + arg_types.push_back(arg_type); + } + + std::vector type_constraints; + for (size_t i = 0; i < ft1->type_constraints.size(); i++) { + Type unified_constraint = Unify(ft1->type_constraints[i], + ft2->type_constraints[i]); + const auto* tcn = unified_constraint.as(); + CHECK(tcn) << "Two type constraints unified into a non-constraint?" + << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; + type_constraints.push_back(GetRef(tcn)); + } + + return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); + } + + private: + TypeSolver* solver_; +}; + +class TypeSolver::Resolver : public TypeMutator { + public: + explicit Resolver(TypeSolver* solver) : solver_(solver) {} + + Type Resolve(const Type& t) { + if (!t.defined()) { + return t; + } + return VisitType(t); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + auto* node = solver_->GetTypeNode(GetRef(op)); + return node->resolved_type; + } + + private: + TypeSolver* solver_; +}; + +// It ends up being more compact to simply have TypeFunctor { + public: + explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) + : solver_(solver), rels_(rels) {} + + // adds the relation node to t and all child types of t + void Propagate(const Type& t) { + VisitType(t); + } + + void UpdateRelSet(const Type& t) { + TypeNode* tnode = solver_->GetTypeNode(t); + for (auto* rel : *rels_) { + tnode->rel_set.insert(rel); + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + UpdateRelSet(t); + } + + void VisitType_(const TupleTypeNode* op) override { + TupleType tt = GetRef(op); + UpdateRelSet(tt); + + for (const Type& t : tt->fields) { + Propagate(t); + } + } + + void VisitType_(const FuncTypeNode* op) override { + FuncType ft = GetRef(op); + UpdateRelSet(ft); + + Propagate(ft->ret_type); + for (auto arg_type : ft->arg_types) { + Propagate(arg_type); + } + + for (auto type_param : ft->type_params) { + Propagate(type_param); + } + + for (auto type_cs : ft->type_constraints) { + Propagate(type_cs); + } + } + + private: + TypeSolver* solver_; + const std::unordered_set* rels_; +}; + +// similarly, we use TypeFunctor so we can use +// the default visitor case to avoid more overrides +class TypeSolver::Merger : public TypeFunctor { + public: + explicit Merger(TypeSolver* solver) : solver_(solver) {} + + // Merges src node to dst, ensures *all* type relations of all + // child nodes of src are transferred to dst. + void Merge(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + dst_ = dst; + VisitType(src->resolved_type); + // set parent at the end so later calls to GetTypeNode go back to src + src->parent = dst; + + // now propagate relations to child nodes, since change to + // a child node should update parent too + Propagator prop(solver_, &dst->rel_set); + prop.Propagate(dst->resolved_type); + } + + // Transfers any relations linked to t to the stored dst. + // Any unresolved relations are added back to the queue, since + // there is now new information + void TransferLinks(const Type& t) { + TypeNode* src = solver_->GetTypeNode(t); + if (src == dst_) return; + for (auto* rel : src->rel_set) { + // if the relation is not yet resolved, add to queue + if (!rel->resolved) { + solver_->AddToQueue(rel); + dst_->rel_set.insert(rel); + } + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + TransferLinks(t); + } + + void VisitType_(const TupleTypeNode* ttn) override { + auto tup = GetRef(ttn); + TransferLinks(tup); + + for (auto field : tup->fields) { + VisitType(field); + } + } + + void VisitType_(const FuncTypeNode* ftn) override { + auto func = GetRef(ftn); + TransferLinks(func); + + VisitType(func->ret_type); + for (auto arg : func->arg_types) { + VisitType(arg); + } + for (auto param : func->type_params) { + VisitType(param); + } + for (auto constraint : func->type_constraints) { + VisitType(constraint); + } + } + + private: + TypeSolver* solver_; + TypeNode* dst_; +}; + // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)) { + : reporter_(make_node(this)) { } // destructor @@ -54,31 +344,16 @@ TypeSolver::~TypeSolver() { } } +// merge src type node to dst +void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { + Merger merger(this); + merger.Merge(src, dst); +} + // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - // Known limitation - // - handle composite types whose component can be unknown. - // - handle shape pattern matching - TypeNode* lhs = GetTypeNode(dst); - TypeNode* rhs = GetTypeNode(src); - - // do occur check so we don't create self-referencing structure - if (lhs->FindRoot() == rhs->FindRoot()) { - return lhs->resolved_type; - } - if (lhs->resolved_type.as()) { - MergeFromTo(lhs, rhs); - return rhs->resolved_type; - } else if (rhs->resolved_type.as()) { - MergeFromTo(rhs, lhs); - return lhs->resolved_type; - } else { - lhs->parent = rhs; - CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type)) - << "Incompatible parent types in UF:" - << lhs->resolved_type << " and " << rhs->resolved_type; - return rhs->resolved_type; - } + Unifier unifier(this); + return unifier.Unify(dst, src); } // Add type constraint to the solver. @@ -96,9 +371,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - LinkNode* rlink = arena_.make >(); - rlink->value = rnode; - tnode->rel_list.Push(rlink); + std::unordered_set singleton { rnode }; + Propagator prop(this, &singleton); + prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. this->AddToQueue(rnode); @@ -110,12 +385,10 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { // Resolve a type in the solver context. Type TypeSolver::Resolve(const Type& type) { + Resolver resolver(this); auto it = tmap_.find(type); - if (it != tmap_.end()) { - return it->second->FindRoot()->resolved_type; - } else { - return type; - } + Type t = (it != tmap_.end()) ? it->second->FindRoot()->resolved_type : type; + return resolver.Resolve(t); } bool TypeSolver::Solve() { @@ -128,7 +401,7 @@ bool TypeSolver::Solve() { // update the relation with given evidence. Array args; for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { - args.push_back(tlink->value->FindRoot()->resolved_type); + args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); CHECK_LE(args.size(), rel->args.size()); } // call the function @@ -161,8 +434,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([solver](Type lhs, Type rhs) { - solver->Unify(lhs, rhs); + return TypedPackedFunc([solver](Type lhs, Type rhs) { + return solver->Unify(lhs, rhs); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 2f311c9b98102..b4635fdec331a 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -18,6 +18,7 @@ namespace relay { using common::LinkNode; using common::LinkedList; + /*! * \brief Interface of type solver used in type inference. * @@ -65,6 +66,11 @@ class TypeSolver { Type Unify(const Type& lhs, const Type& rhs); private: + class OccursChecker; + class Unifier; + class Resolver; + class Propagator; + class Merger; class Reporter; struct TypeNode; struct RelationNode; @@ -77,15 +83,15 @@ class TypeSolver { * that can unifies the same types to the name resolved_type. * * It also contains collection of links to related Relations, - * which is stored in rel_list. + * which is stored in rel_set. */ struct TypeNode { /*! \brief The final resolved type */ Type resolved_type; /*! \brief type node in the union find algorithm */ TypeNode* parent{nullptr}; - /*! \brief list of relations that is related to this type node */ - LinkedList rel_list; + /*! \brief set of relations that is related to this type node */ + std::unordered_set rel_set; /*! * \brief Find the root type node, perform path compression * \return The root type node. @@ -125,7 +131,7 @@ class TypeSolver { size_t num_resolved_rels_{0}; /*! \brief map from type node to types. */ std::unordered_map tmap_; - /*! \breif Internal queue to update the relation */ + /*! \brief Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ common::Arena arena_; @@ -163,22 +169,7 @@ class TypeSolver { * \param src The source operand * \param dst The dst operand. */ - void MergeFromTo(TypeNode* src, TypeNode* dst) { - if (src == dst) return; - src->parent = dst; - // move the link to the to dst - for (auto* rlink = src->rel_list.head; rlink != nullptr;) { - // store next pointer first before rlink get moved - auto* next = rlink->next; - // if the relation is not yet resolved - // send the relation to the new - if (!rlink->value->resolved) { - this->AddToQueue(rlink->value); - dst->rel_list.Push(rlink); - } - rlink = next; - } - } + void MergeFromTo(TypeNode* src, TypeNode* dst); }; } // namespace relay diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index b99d975135bea..403863c1d757b 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -12,105 +12,211 @@ namespace tvm { namespace relay { -// FreeTypeVar -class FreeTypeVarTVisitor : public TypeVisitor { +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class TypeVarTVisitor : public TypeVisitor { public: - FreeTypeVarTVisitor( - Array* free_vars, - std::unordered_set* bound_vars) - : free_vars_(free_vars), bound_vars_(bound_vars) { } + TypeVarTVisitor( + InsertionSet* type_vars, + InsertionSet* bound_type_vars) + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); - if (bound_vars_->count(var) == 0) { - free_vars_->push_back(var); - } + type_vars_->Insert(var); } void VisitType_(const FuncTypeNode* f) final { for (auto type_param : f->type_params) { - bound_vars_->insert(type_param); + type_vars_->Insert(type_param); + bound_type_vars_->Insert(type_param); } TypeVisitor::VisitType_(f); } private: - Array* free_vars_; - std::unordered_set* bound_vars_; + InsertionSet* type_vars_; + InsertionSet* bound_type_vars_; }; -class FreeTypeVarEVisitor : private ExprVisitor { +class TypeVarEVisitor : private ExprVisitor { public: - Array Find(const Expr& expr) { - this->VisitExpr(expr); - return free_vars_; + Array CollectFree() { + Array ret; + for (const auto& v : type_vars_.data) { + if (bound_type_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array CollectBound() { + Array ret; + for (const auto& v : bound_type_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array CollectAll() { + Array ret; + for (const auto& v : type_vars_.data) { + ret.push_back(v); + } + return ret; } - Array Find(const Type& type) { - this->VisitType(type); - return free_vars_; + Array Free(const Expr& expr) { + VisitExpr(expr); + return CollectFree(); + } + + Array Free(const Type& type) { + VisitType(type); + return CollectFree(); + } + + Array Bound(const Expr& expr) { + VisitExpr(expr); + return CollectBound(); + } + + Array Bound(const Type& type) { + VisitType(type); + return CollectBound(); + } + + Array All(const Expr& expr) { + VisitExpr(expr); + return CollectAll(); + } + + Array All(const Type& type) { + VisitType(type); + return CollectAll(); } void VisitExpr_(const FunctionNode* f) final { for (const auto& tp : f->type_params) { - bound_vars_.insert(tp); + type_vars_.Insert(tp); + bound_type_vars_.Insert(tp); } ExprVisitor::VisitExpr_(f); } void VisitType(const Type& t) final { - FreeTypeVarTVisitor(&free_vars_, &bound_vars_) + TypeVarTVisitor(&type_vars_, &bound_type_vars_) .VisitType(t); } private: - // The result list - Array free_vars_; - std::unordered_set bound_vars_; + InsertionSet type_vars_; + InsertionSet bound_type_vars_; }; -class FreeVarVisitor : protected ExprVisitor { +class VarVisitor : protected ExprVisitor { public: - Array Find(const Expr& expr) { + Array Free(const Expr& expr) { this->VisitExpr(expr); - return free_vars_; + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; } - void VisitExpr_(const VarNode* var) final { - if (bound_vars_.count(var) == 0) { - free_vars_.push_back(GetRef(var)); + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); } + return ret; + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { + vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { - bound_vars_.insert(param.operator->()); + MarkBounded(param); } VisitExpr(op->body); } void VisitExpr_(const LetNode* op) final { - bound_vars_.insert(op->var.operator->()); + MarkBounded(op->var); VisitExpr(op->value); VisitExpr(op->body); } private: - // The result list - Array free_vars_; - std::unordered_set bound_vars_; + InsertionSet vars_; + InsertionSet bound_vars_; }; tvm::Array FreeTypeVars(const Expr& expr) { - return FreeTypeVarEVisitor().Find(expr); + return TypeVarEVisitor().Free(expr); } tvm::Array FreeTypeVars(const Type& type) { - return FreeTypeVarEVisitor().Find(type); + return TypeVarEVisitor().Free(type); +} + +tvm::Array BoundTypeVars(const Expr& expr) { + return TypeVarEVisitor().Bound(expr); +} + +tvm::Array BoundTypeVars(const Type& type) { + return TypeVarEVisitor().Bound(type); +} + +tvm::Array AllTypeVars(const Expr& expr) { + return TypeVarEVisitor().All(expr); +} + +tvm::Array AllTypeVars(const Type& type) { + return TypeVarEVisitor().All(type); } tvm::Array FreeVars(const Expr& expr) { - return FreeVarVisitor().Find(expr); + return VarVisitor().Free(expr); +} + +tvm::Array BoundVars(const Expr& expr) { + return VarVisitor().Bound(expr); +} + +tvm::Array AllVars(const Expr& expr) { + return VarVisitor().All(expr); } TVM_REGISTER_API("relay._ir_pass.free_vars") @@ -118,16 +224,46 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") *ret = FreeVars(args[0]); }); +TVM_REGISTER_API("relay._ir_pass.bound_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = BoundVars(args[0]); + }); + +TVM_REGISTER_API("relay._ir_pass.all_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = AllVars(args[0]); + }); + TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as()) { + if (x.as_derived()) { *ret = FreeTypeVars(Downcast(x)); } else { *ret = FreeTypeVars(Downcast(x)); } }); +TVM_REGISTER_API("relay._ir_pass.bound_type_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as_derived()) { + *ret = BoundTypeVars(Downcast(x)); + } else { + *ret = BoundTypeVars(Downcast(x)); + } + }); + +TVM_REGISTER_API("relay._ir_pass.all_type_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as_derived()) { + *ret = AllTypeVars(Downcast(x)); + } else { + *ret = AllTypeVars(Downcast(x)); + } + }); + /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 385bde9740149..50aed4c57338e 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -6,13 +6,17 @@ TEST(Relay, SelfReference) { using namespace tvm; - auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); - auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); - auto x = relay::VarNode::make("x", type_a); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, Array{}); - auto fx = relay::CallNode::make(f, Array{ x }); + auto tensor_type = relay::TensorTypeNode::make({}, ::tvm::Bool()); + auto x = relay::VarNode::make("x", relay::Type()); + auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); + + auto y = relay::VarNode::make("y", tensor_type); + auto call = relay::CallNode::make(f, Array{ y }); + auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); - CHECK_EQ(type_fx->checked_type(), type_a); + + auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); + CHECK(AlphaEqual(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) { diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py deleted file mode 100644 index 151dbe1412bc6..0000000000000 --- a/tests/python/relay/test_pass_free_vars.py +++ /dev/null @@ -1,41 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.ir_pass import free_vars, free_type_vars - -def test_free_vars(): - ty = relay.TensorType([], "int32") - x = relay.Var("x", ty) - fvx = free_vars(x) - assert len(fvx) == 1 - assert fvx[0] == x - v = relay.Constant(tvm.nd.array(10)) - - let = relay.Let(x, v, x) - fvx = free_vars(let) - assert len(free_vars(let)) == 0 - f = relay.Function([x], x, ty) - assert len(free_vars(f)) == 0 - - -def test_tuple(): - t = relay.Var('t') - fv = free_vars(relay.Tuple([t, t])) - assert len(fv) == 1 - assert fv[0] == t - fv = free_vars(relay.TupleGetItem(t, 123)) - assert len(fv) == 1 - assert fv[0] == t - - -def test_free_type_vars(): - tp = relay.TypeVar("") - ty = relay.TupleType([tp, relay.TensorType([], "int32")]) - x = relay.Var("x", ty) - y = relay.Var("y") - let = relay.Let(x, y, x) - fvl = free_vars(let) - assert len(fvl) == 1 - assert fvl[0] == y - ftvl = free_type_vars(let) - assert len(ftvl) == 1 - assert ftvl[0] == tp diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py new file mode 100644 index 0000000000000..c8d3d6d14992a --- /dev/null +++ b/tests/python/relay/test_pass_vars.py @@ -0,0 +1,144 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import (free_vars, free_type_vars, + bound_vars, bound_type_vars, + all_vars, all_type_vars) + +def assert_vars_match(actual, expected): + assert len(actual) == len(expected) + for i in range(len(actual)): + assert actual[i] == expected[i] + + +def test_free_vars(): + ty = relay.TensorType([], "int32") + x = relay.Var("x", ty) + fvx = free_vars(x) + assert len(fvx) == 1 + assert fvx[0] == x + v = relay.Constant(tvm.nd.array(10)) + + let = relay.Let(x, v, x) + fvx = free_vars(let) + assert len(free_vars(let)) == 0 + f = relay.Function([x], x, ty) + assert len(free_vars(f)) == 0 + + +def test_free_vars_tuple(): + t = relay.Var('t') + fv = free_vars(relay.Tuple([t, t])) + assert len(fv) == 1 + assert fv[0] == t + fv = free_vars(relay.TupleGetItem(t, 123)) + assert len(fv) == 1 + assert fv[0] == t + + +def test_free_type_vars(): + tp = relay.TypeVar("") + ty = relay.TupleType([tp, relay.TensorType([], "int32")]) + x = relay.Var("x", ty) + y = relay.Var("y") + let = relay.Let(x, y, x) + fvl = free_vars(let) + assert len(fvl) == 1 + assert fvl[0] == y + ftvl = free_type_vars(let) + assert len(ftvl) == 1 + assert ftvl[0] == tp + + +def test_bound_vars(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + a = relay.Var("a") + + f1 = relay.Function([x, y, z], relay.Let(a, x, relay.Tuple([]))) + assert_vars_match(bound_vars(f1), [x, y, z, a]) + + tup = relay.Tuple([x, y, z, a]) + assert len(bound_vars(tup)) == 0 + + f2 = relay.Function([x, y], relay.Tuple([x, y, z, a])) + assert_vars_match(bound_vars(f2), [x, y]) + + +def test_bound_type_vars(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + + ft1 = relay.FuncType([a], b, [a, b]) + bound_ft1 = bound_type_vars(ft1) + assert_vars_match(bound_type_vars(ft1), [a, b]) + + ft2 = relay.FuncType([], c, [a]) + assert_vars_match(bound_type_vars(ft2), [a]) + + tup_ty = relay.TupleType([a, b, c]) + assert len(bound_type_vars(tup_ty)) == 0 + + f1 = relay.Function([], relay.Tuple([]), type_params=[a, b]) + assert_vars_match(bound_type_vars(f1), [a, b]) + + f2 = relay.Function([], relay.Tuple([]), c) + assert len(bound_type_vars(f2)) == 0 + + x = relay.Var("x", a) + let1 = relay.Let(x, relay.Tuple([]), x) + assert len(bound_type_vars(let1)) == 0 + + let2 = relay.Let(x, relay.Function([], relay.Tuple([]), type_params=[b, c]), x) + assert_vars_match(bound_type_vars(let2), [b, c]) + + +def test_all_vars(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + + f1 = relay.Function([x, y], z) + assert_vars_match(all_vars(f1), [x, y, z]) + + f2 = relay.Function([x], relay.Let(y, relay.Tuple([]), z)) + assert_vars_match(all_vars(f2), [x, y, z]) + + f3 = relay.Function([x], relay.Tuple([y, z])) + assert_vars_match(all_vars(f3), [x, y, z]) + + tup = relay.Tuple([x, y, z]) + assert_vars_match(all_vars(tup), [x, y, z]) + + +def test_all_type_vars(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + + ft1 = relay.FuncType([b], c, [a]) + assert_vars_match(all_type_vars(ft1), [a, b, c]) + + ft2 = relay.FuncType([], relay.TupleType([a, b, c]), []) + assert_vars_match(all_type_vars(ft2), [a, b, c]) + + w = relay.Var("w") + x = relay.Var("x", a) + y = relay.Var("y", b) + z = relay.Var("z", c) + + f1 = relay.Function([x], y, b, [a]) + assert_vars_match(all_type_vars(f1), [a, b]) + + f2 = relay.Function([x], relay.Let(y, x, z)) + assert_vars_match(all_type_vars(f2), [a, b, c]) + + f3 = relay.Function([], relay.Tuple([x, y, z]), ret_type=relay.TupleType([a, b, c])) + assert_vars_match(all_type_vars(f3), [a, b, c]) + + f4 = relay.Function([w], relay.Tuple([]), type_params=[a, b, c]) + assert_vars_match(all_type_vars(f4), [a, b, c]) + + f5 = relay.Function([w], w) + assert len(all_type_vars(f5)) == 0 diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 06cb19639dcfd..ac4eb1b404dbc 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -23,7 +23,7 @@ def test_monomorphic_let(): x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) xchecked = relay.ir_pass.infer_type(sb.get()) - assert xchecked.checked_type == relay.scalar_type("float64") + assert xchecked.checked_type == relay.scalar_type("float64" ) def test_single_op(): @@ -41,14 +41,15 @@ def test_add_broadcast_op(): return x + y; } """ - pass - # x = relay.var('x', shape=(10, 4)) - # y = relay.var('y', shape=(5, 10, 1)) - # z = x + y - # func = relay.Function([x, y], z) - # ttype = relay.TensorType((5, 5, 5), 'float32') - # expected_ty = relay.FuncType([ttype, ttype], ttype) - # assert_has_type(func.to_func(), expected_ty) + x = relay.var('x', shape=(10, 4)) + y = relay.var('y', shape=(5, 10, 1)) + z = x + y + func = relay.Function([x, y], z) + t1 = relay.TensorType((10, 4), 'float32') + t2 = relay.TensorType((5, 10, 1), 'float32') + t3 = relay.TensorType((5, 10, 4), 'float32') + expected_ty = relay.FuncType([t1, t2], t3) + assert_has_type(func, expected_ty) def test_dual_op(): @@ -110,24 +111,17 @@ def f(n: i32, data: f32) -> f32 { assert "%3 = @f(%1, %2)" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) -# This currently fails and should pass under the type system. -# -# This test is to illustrate problem with our weak form of -# unification. -# - def test_incomplete_call(): - sb = ScopeBuilder() - x = relay.var('x', dtype='int32') + tt = relay.scalar_type('int32') + x = relay.var('x', tt) f = relay.var('f') - func = relay.Function([x, f], relay.Call(f, [x])) + func = relay.Function([x, f], relay.Call(f, [x]), tt) + + ft = relay.ir_pass.infer_type(func) + f_type = relay.FuncType([tt], tt) + assert ft.checked_type == relay.FuncType([tt, f_type], tt) - try: - relay.ir_pass.infer_type(func) - assert False - except tvm.TVMError as e: - assert True def test_tuple(): tp = relay.TensorType((10,)) @@ -136,6 +130,7 @@ def test_tuple(): assert (relay.ir_pass.infer_type(res).checked_type == relay.TupleType([tp, tp])) + def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -161,38 +156,26 @@ def test_type_args(): assert sh2[1].value == 10 -def test_self_reference(): - """ - Program: - def f(x) { - return x; - } - """ - a = relay.TypeVar("a") - x = relay.var("x", a) - sb = relay.ScopeBuilder() - - f = relay.Function([x], x) - fx = relay.Call(f, [x]) - assert relay.ir_pass.infer_type(x).checked_type == a - assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) - assert relay.ir_pass.infer_type(fx).checked_type == a - - -def test_global_var_cow_issue(): +def test_global_var_recursion(): mod = relay.Module({}) gv = relay.GlobalVar("foo") x = relay.var('x', shape=[]) - func = relay.Function([x], relay.Call(gv, [x]), - relay.TensorType([], 'float32')) + tt = relay.scalar_type('float32') + + func = relay.Function([x], relay.Call(gv, [x]), tt) mod[gv] = func + ft = relay.ir_pass.infer_type(gv, mod) + assert mod[ft].checked_type == relay.FuncType([tt], tt) + def test_equal(): i = relay.var('i', shape=[], dtype='int32') eq = op.equal(i, relay.const(0, dtype='int32')) - # This should fail .... - func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) + func = relay.Function([i], eq) + ft = relay.ir_pass.infer_type(func) + + assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) if __name__ == "__main__": @@ -204,8 +187,12 @@ def test_equal(): test_decl() test_recursion() test_tuple() + test_generalized_tuple() test_incomplete_call() + test_generalized_call() + test_call_with_type_args() test_free_expr() test_type_args() test_self_reference() - test_global_var_cow_issue() + test_global_var_recursion() + test_equal() diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index e8ff67756931b..1e2fed0af1f8e 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -1,5 +1,6 @@ import tvm from tvm import relay +from nose.tools import raises def make_rel(name, args, num_inputs=None, attrs=None): @@ -48,7 +49,170 @@ def test_backward_solving(): assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32") +def test_unify_tuple(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.TensorType((10, 20), "float32") + + tup1 = relay.ty.TupleType([t1, t2]) + tup2 = relay.ty.TupleType([t3, t3]) + + unified = solver.Unify(tup1, tup2) + assert unified == tup2 + + +def test_unify_functype(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + unit = relay.ty.TupleType([]) + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10,), "float32") + + ft1 = relay.ty.FuncType([t1, t2], t3) + ft2 = relay.ty.FuncType([tensor1, tensor2], unit) + + unified = solver.Unify(ft1, ft2) + assert unified == ft2 + + +def test_recursive_unify(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tensor1 = relay.ty.TensorType((10, 10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 20), "float32") + tensor3 = relay.ty.TensorType((10,), "float32") + + tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2]) + + ft1 = relay.ty.FuncType([tup1, t3], t3) + ft2 = relay.ty.FuncType([tup2, tensor3], tensor3) + + unified = solver.Unify(ft1, ft2) + assert unified == ft2 + + +def test_unify_vars_under_tuples(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([t1, t1]) + unified = solver.Unify(tup1, tup1) + assert unified == tup1 + + t2 = relay.ty.IncompleteType() + tup2 = relay.ty.TupleType([t2, t2]) + + tup3 = relay.ty.TupleType([t1, t2]) + tup4 = relay.ty.TupleType([t2, t1]) + unified = solver.Unify(tup3, tup4) + assert (unified == tup1 or unified == tup2) + + +def test_binding_over_typevars(): + solver = make_solver() + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + a = relay.ty.TypeVar('a') + b = relay.ty.TypeVar('b') + c = relay.ty.TypeVar('c') + d = relay.ty.TypeVar('d') + + ft1 = relay.ty.FuncType([t1], t2, [c, d]) + ft2 = relay.ty.FuncType([a], b, [a, b]) + unified = solver.Unify(ft1, ft2) + assert (unified == solver.Resolve(ft1)) + + +def test_recursive_backward_solving(): + solver = make_solver() + + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 1, 1), "float32") + tensor3 = relay.ty.TensorType((10,), "float32") + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3]) + solver.gen_type("Identity", [tup1], out=tup2) + + assert solver.Solve() + assert solver.Resolve(tup2) == tup1 + + +def test_backward_solving_after_child_update(): + solver = make_solver() + + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 1, 1), "float32") + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([t1, t2]) + tup2 = relay.ty.TupleType([t1, t3]) + + tup_concrete = relay.ty.TupleType([tensor1, tensor2]) + + t4 = solver.gen_type("Identity", [tup1]) + t5 = solver.gen_type("Identity", [tup2]) + + solver.gen_type("Identity", [t4], out=t5) + assert solver.Solve() + assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2 + assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2 + assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2 + + # updating the variables *inside* tup1 and tup2 should update t4 and t5 + solver.gen_type("Identity", [t1], out=tensor1) + solver.gen_type("Identity", [t2], out=tensor2) + assert solver.Solve() + assert solver.Resolve(t4) == tup_concrete + assert solver.Resolve(t5) == tup_concrete + +@raises(tvm._ffi.base.TVMError) +def test_incompatible_tuple_unification(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + tensor1 = relay.ty.TensorType((1, 2, 3), "float32") + tensor2 = relay.ty.TensorType((2, 3), "float32") + tensor3 = relay.ty.TensorType((3,), "float32") + + tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) + solver.Unify(tup1, tup2) + + +@raises(tvm._ffi.base.TVMError) +def test_bad_recursive_unification(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + solver.Unify(t1, relay.ty.TupleType([t1, t1])) + if __name__ == "__main__": test_bcast() test_backward_solving() + test_unify_tuple() + test_unify_functype() + test_recursive_unify() + test_unify_vars_under_tuples() + test_recursive_backward_solving() + test_backward_solving_after_child_update() + test_incompatible_tuple_unification() + test_bad_recursive_unification() From 52e3bd329da2bfd2f9d0fe9ac7eb3f3795a7e220 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Wed, 16 Jan 2019 17:21:40 -0800 Subject: [PATCH 03/16] [Runtime] Make runtime compatible with android ndk api 15 (#2446) --- src/runtime/cpu_device_api.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index d166a3a43dfab..6bd7022fd9f6d 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -10,6 +10,10 @@ #include #include "workspace_pool.h" +#ifdef __ANDROID__ +#include +#endif + namespace tvm { namespace runtime { class CPUDeviceAPI final : public DeviceAPI { @@ -28,10 +32,11 @@ class CPUDeviceAPI final : public DeviceAPI { #if _MSC_VER ptr = _aligned_malloc(nbytes, alignment); if (ptr == nullptr) throw std::bad_alloc(); -#elif defined(_LIBCPP_SGX_CONFIG) +#elif defined(_LIBCPP_SGX_CONFIG) || (defined(__ANDROID__) && __ANDROID_API__ < 16) ptr = memalign(alignment, nbytes); if (ptr == nullptr) throw std::bad_alloc(); #else + // posix_memalign is available in android ndk since __ANDROID_API__ >= 16 int ret = posix_memalign(&ptr, alignment, nbytes); if (ret != 0) throw std::bad_alloc(); #endif From 2a871f35acb0ae31bccf6747073603681c8044ff Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 17 Jan 2019 09:29:12 +0800 Subject: [PATCH 04/16] [RELAY][PASS] Support Negative Scale in FoldScaleAxis (#2426) * [RELAY][PASS] Support Negative Scale in FoldScaleAxis * Fix comment --- src/relay/pass/fold_scale_axis.cc | 328 ++++++++++-------- .../python/relay/test_pass_fold_scale_axis.py | 95 ++++- 2 files changed, 281 insertions(+), 142 deletions(-) diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 60df5d90a87cf..0cd46ff330e11 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -59,6 +59,36 @@ using runtime::TypedPackedFunc; */ using AxesSet = Array; +class Message; + +/*! + * \brief Message propogated during the prepare phase. + */ +class MessageNode : public RelayNode { + public: + /*! \brief Axes for scaling */ + AxesSet axes; + /*! + * \brief Whether folding requires the scale to be positive constant. This is necessary if some + * operators (e.g. Relu) is present. + */ + bool require_positive; + + static Message make(const AxesSet& axes, bool require_positive); + + static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message"; + TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode); +}; + +RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef); + +Message MessageNode::make(const AxesSet& axes, bool require_positive) { + auto n = make_node(); + n->axes = axes; + n->require_positive = require_positive; + return Message(n); +} + /*! * \brief Merge two axis set together by taking * intersection. @@ -88,14 +118,29 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { return ret; } +/*! + * \brief Merge two messages together by taking intersection. + * + * \param lhs The lhs message. + * \param rhs The rhs message. + * \return The result of intersection. + */ +Message Intersect(const Message& lhs, const Message& rhs) { + if (!lhs.defined()) return lhs; + if (!rhs.defined()) return rhs; + auto axes = Intersect(lhs->axes, rhs->axes); + return MessageNode::make(axes, lhs->require_positive || rhs->require_positive); +} + /*! * \brief Preparation function for pass scale forward. * \param call The call node. - * \param out_scale_axes Possible scaling on axes of the output. - * \return The result scaling on axes of the input. + * \param out_message Message from the output containing possible scaling on axes and whether + * positive scale is required. + * \return The message containing the result scaling on axes of the input. */ using FForwardPrep = runtime::TypedPackedFunc< - Array (const Call& call, const AxesSet& out_scale_axes)>; + Array (const Call& call, const Message& out_message)>; /*! \brief Axis scale tuple. */ class ScaledExprNode : public TempExprNode { @@ -126,16 +171,16 @@ class ScaledExprNode : public TempExprNode { using FForwardRewrite = TypedPackedFunc< Expr(const Call& ref_call, const Array& new_args, - const AxesSet& expeced_out_axes)>; + const Message& message)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map + std::unordered_map Prepare(const Expr& body) { - this->Update(body, NullValue()); + this->Update(body, NullValue()); this->VisitExpr(body); // flist is added in the Post-DFS order // which is a special case of topological order. @@ -152,9 +197,9 @@ class ForwardPrep : private ExprVisitor { // The invoke list std::vector > flist_; // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // Update the message stored at node. - void Update(const Expr& node, const AxesSet& axes) { + void Update(const Expr& node, const Message& message) { // We run intersection of messages: // // %y = multiply(%x, %scale) @@ -167,9 +212,9 @@ class ForwardPrep : private ExprVisitor { // and the forward folding won't be triggered. const Node* key = node.get(); if (message_.count(key)) { - message_[key] = Intersect(message_[key], axes); + message_[key] = Intersect(message_[key], message); } else { - message_[key] = axes; + message_[key] = message; } } // Visitor pattern override. @@ -180,7 +225,7 @@ class ForwardPrep : private ExprVisitor { void VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); auto flazy = [this, op] { - this->Update(op->body, NullValue()); + this->Update(op->body, NullValue()); }; flist_.push_back(flazy); } @@ -193,23 +238,23 @@ class ForwardPrep : private ExprVisitor { Op::GetAttr("FScaleAxisForwardPrep"); // find the message send to this node. auto it = message_.find(call); - AxesSet out_axes; + Message out_message; if (it != message_.end()) { - out_axes = it->second; + out_message = it->second; } else { - out_axes = NullValue(); + out_message = NullValue(); } // pass the message back to all the children it references. auto f = fprep.get(call->op, nullptr); if (f != nullptr) { - Array in_axes = f(GetRef(call), out_axes); - CHECK_EQ(in_axes.size(), call->args.size()); + Array in_messages = f(GetRef(call), out_message); + CHECK_EQ(in_messages.size(), call->args.size()); for (size_t i = 0; i < call->args.size(); ++i) { - this->Update(call->args[i], in_axes[i]); + this->Update(call->args[i], in_messages[i]); } } else { for (size_t i = 0; i < call->args.size(); ++i) { - this->Update(call->args[i], NullValue()); + this->Update(call->args[i], NullValue()); } } }; @@ -221,7 +266,7 @@ class ForwardPrep : private ExprVisitor { // do not support pass scale through tuple for now. auto flazy = [this, op]() { for (const Expr& field : op->fields) { - this->Update(field, NullValue()); + this->Update(field, NullValue()); } }; flist_.push_back(flazy); @@ -230,13 +275,13 @@ class ForwardPrep : private ExprVisitor { void VisitExpr_(const IfNode* op) { ExprVisitor::VisitExpr_(op); // do pass through condition - // by assigning NullValue + // by assigning NullValue // it means fuse signal cannot pass // through into these subexpressions. auto flazy = [this, op]() { - this->Update(op->cond, NullValue()); - this->Update(op->true_branch, NullValue()); - this->Update(op->false_branch, NullValue()); + this->Update(op->cond, NullValue()); + this->Update(op->true_branch, NullValue()); + this->Update(op->false_branch, NullValue()); }; flist_.push_back(flazy); } @@ -247,13 +292,16 @@ class ForwardPrep : private ExprVisitor { //---------------------------------------------- // Intermediate operators -Array ReluForwardPrep(const Call& call, AxesSet out) { - return {out}; +Array ReluForwardPrep(const Call& call, const Message& out_message) { + if (out_message.defined()) { + return {MessageNode::make(out_message->axes, true)}; + } + return {out_message}; } Expr ReluForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_axes) { + const Message& message) { const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d @@ -278,23 +326,23 @@ RELAY_REGISTER_OP("nn.leaky_relu") .set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub -Array AddSubForwardPrep(const Call& call, AxesSet out_axes) { +Array AddSubForwardPrep(const Call& call, const Message& out_message) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - - auto none = NullValue(); - if (MatchBroadcastToLeftAxes(tlhs, trhs, out_axes)) { - return {out_axes, none}; - } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_axes)) { - return {none, out_axes}; - } else { - return {none, none}; + auto none = NullValue(); + if (out_message.defined()) { + if (MatchBroadcastToLeftAxes(tlhs, trhs, out_message->axes)) { + return {out_message, none}; + } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out_message->axes)) { + return {none, out_message}; + } } + return {none, none}; } Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_out_axes) { + const Message& message) { const auto* slhs = new_args[0].as(); const auto* srhs = new_args[1].as(); if (!slhs && !srhs) return Expr(); @@ -342,9 +390,10 @@ RELAY_REGISTER_OP("subtract") // Multiply produces the scale-axis pair. Expr MultiplyForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_out_axes) { - if (!expected_out_axes.defined()) return Expr(); - if (expected_out_axes.size() == 0) return Expr(); + const Message& message) { + if (!message.defined()) return Expr(); + const auto& expected_out_axes = message->axes; + CHECK(expected_out_axes.defined() && expected_out_axes.size()); // TODO(tvm-team) allow same axes accumulation // not as important because it is less common in nn. const auto* slhs = new_args[0].as(); @@ -356,14 +405,15 @@ Expr MultiplyForwardRewrite(const Call& ref_call, Expr lhs = new_args[0]; Expr rhs = new_args[1]; auto rnode = make_node(); + if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) && - IsAllPositiveConstant(rhs)) { + (!message->require_positive || IsAllPositiveConstant(rhs))) { rnode->value = lhs; rnode->scale = rhs; rnode->axes = expected_out_axes; return Expr(rnode); } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) && - IsAllPositiveConstant(lhs)) { + (!message->require_positive || IsAllPositiveConstant(lhs))) { rnode->value = rhs; rnode->scale = lhs; rnode->axes = expected_out_axes; @@ -378,7 +428,7 @@ RELAY_REGISTER_OP("multiply") // Consumer operators // Conv2D send out requirement of axis folding. -Array Conv2DForwardPrep(const Call& call, AxesSet out) { +Array Conv2DForwardPrep(const Call& call, const Message& out_message) { // TODO(tvm-team) support general data layout // by transforming weight const auto* param = call->attrs.as(); @@ -389,6 +439,7 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { int c_small_axis = data_layout.Indexof('c'); CHECK_GE(c_big_axis, 0); + Message none = NullValue(); AxesSet data_axes = NullValue(); // For now, we only support simple pattern (no folded weight/data) // More general layout can be supported under the current framework. @@ -403,13 +454,16 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { (param->groups == 1 || is_depthwise_conv2d)) { data_axes = {c_big_axis}; } - return {data_axes, NullValue()}; + if (data_axes.defined()) { + return {MessageNode::make(data_axes, false), none}; + } + return {none, none}; } // Conv2D consumes the scale axis during transformation. Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, - const AxesSet& expected_axes) { + const Message& message) { // if data do not have scale, normal transform path. const auto* sdata = new_args[0].as(); const auto* sweight = new_args[1].as(); @@ -458,11 +512,10 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(Expr data) { - auto expected_scale_axes = - ForwardPrep().Prepare(data); + auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ - auto it = expected_scale_axes.find(call.get()); - if (it != expected_scale_axes.end()) { + auto it = message.find(call.get()); + if (it != message.end()) { return it->second; } else { return NodeRef(nullptr); @@ -484,15 +537,16 @@ class BackwardTransformer; /*! * \brief Preparation function for for pass scale backward. * \param call The call node. - * \param in_scale_axes Allowed input scaling. - * \return The result scaling on axes of the input. + * \param in_messages Messages from the input containing allowed input scaling and whether + * positive scale is required. + * \return Message containing the result scaling on axes of the input. */ using FBackwardPrep = TypedPackedFunc< - AxesSet(const Call& call, const Array& in_scale_axes)>; + Message(const Call& call, const Array& in_messages)>; using FBackwardTransform = TypedPackedFunc< Expr(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer)>; @@ -503,7 +557,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); @@ -512,7 +566,7 @@ class BackwardPrep : private ExprVisitor { private: // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // reference counter of an internal expr std::unordered_map ref_counter_; // Visit the expression. @@ -527,18 +581,18 @@ class BackwardPrep : private ExprVisitor { // We only allow propagation of scale backward // if the expression is only referred by a single parent. if (rit->second != 1) return; - Array in_axes; + Array in_messages; for (Expr arg : call->args) { auto it = message_.find(arg.get()); if (it != message_.end()) { - in_axes.push_back(it->second); + in_messages.push_back(it->second); } else { - in_axes.push_back(NullValue()); + in_messages.push_back(NullValue()); } } - AxesSet out_axes = f(GetRef(call), in_axes); - if (out_axes.defined()) { - message_[call] = out_axes; + Message out_message = f(GetRef(call), in_messages); + if (out_message.defined()) { + message_[call] = out_message; } } }; @@ -549,7 +603,7 @@ class BackwardTransformerNode : public: // Run forward transform. Expr Fold(Expr expr) { - expected_scale_axes_ = BackwardPrep().Prepare(expr); + message_ = BackwardPrep().Prepare(expr); return this->Mutate(expr); } /*! @@ -560,12 +614,12 @@ class BackwardTransformerNode : * \param scale The scale applied to the axes. * \return The result of transformation. */ - Expr Transform(const Expr& expr, AxesSet axes, Expr scale) { + Expr Transform(const Expr& expr, Message message, Expr scale) { // NOTE: the result of Transform is memoized. if (const CallNode* call_node = expr.as()) { - return Transform(call_node, axes, scale); + return Transform(call_node, message, scale); } else { - CHECK(!axes.defined()) << "outstanding scale"; + CHECK(!message.defined()) << "outstanding scale"; return ExprMutator::VisitExpr(expr); } } @@ -585,14 +639,14 @@ class BackwardTransformerNode : return new_expr; } /*! - * \brief Get the expected axes on expr. + * \brief Get the message propogated to the expr. * \param expr The expresison. - * \return The expected axes. + * \return The message containing the expected axes and whether positive scale is required. */ - AxesSet GetExpectedAxes(const Expr& expr) const { - auto it = expected_scale_axes_.find(expr.get()); - if (it != expected_scale_axes_.end()) return it->second; - return NullValue(); + Message GetMessage(const Expr& expr) const { + auto it = message_.find(expr.get()); + if (it != message_.end()) return it->second; + return NullValue(); } // solver is not serializable. @@ -603,13 +657,13 @@ class BackwardTransformerNode : private: // Valid axes on each node. - std::unordered_map expected_scale_axes_; + std::unordered_map message_; // Override mutation of call. Expr VisitExpr_(const CallNode* call_node) final { - return Transform(call_node, NullValue(), NullValue()); + return Transform(call_node, NullValue(), NullValue()); } // Transform of CallNode. - Expr Transform(const CallNode* call_node, AxesSet axes, Expr scale); + Expr Transform(const CallNode* call_node, Message message, Expr scale); }; class BackwardTransformer : public NodeRef { @@ -625,7 +679,7 @@ class BackwardTransformer : public NodeRef { }; Expr BackwardTransformerNode::Transform( - const CallNode* call_node, AxesSet axes, Expr scale) { + const CallNode* call_node, Message message, Expr scale) { static const auto& ftransform = Op::GetAttr("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); @@ -636,13 +690,13 @@ Expr BackwardTransformerNode::Transform( return it->second; } Expr new_expr = f(GetRef(call_node), - axes, + message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { - CHECK(!axes.defined()) << "outstanding scale"; + CHECK(!message.defined()) << "outstanding scale"; return NormalCallTransform(call_node); } } @@ -653,19 +707,22 @@ Expr BackwardTransformerNode::Transform( //---------------------------------------------- // Intermediate operators -AxesSet ReluBackwardPrep(const Call& call, const Array& in_axes) { - return in_axes[0]; +Message ReluBackwardPrep(const Call& call, const Array& in_messages) { + if (in_messages[0].defined()) { + return MessageNode::make(in_messages[0]->axes, true); + } + return in_messages[0]; } Expr ReluBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { - if (!axes.defined()) { + if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } Expr input = transformer->Transform( - call->args[0], axes, scale); + call->args[0], message, scale); return CallNode::make(call->op, {input}, call->attrs, call->type_args); } @@ -682,64 +739,63 @@ RELAY_REGISTER_OP("nn.leaky_relu") .set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); // AddSub -AxesSet AddSubBackwardPrep(const Call& call, const Array& in_axes) { +Message AddSubBackwardPrep(const Call& call, const Array& in_messages) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); AttrsEqual equal; - if (in_axes[0].defined() && - MatchBroadcastToLeftAxes(tlhs, trhs, in_axes[0])) { - return in_axes[0]; - } else if (in_axes[1].defined() && - MatchBroadcastToLeftAxes(trhs, tlhs, in_axes[1])) { - return in_axes[1]; - } else if (in_axes[0].defined() && - in_axes[1].defined() && - equal(in_axes[0], in_axes[1]) && + if (in_messages[0].defined() && + MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { + return in_messages[0]; + } else if (in_messages[1].defined() && + MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { + return in_messages[1]; + } else if (in_messages[0].defined() && + in_messages[1].defined() && + equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { // add of two elements. - return in_axes[0]; + return in_messages[0]; } else { - auto res = NullValue(); - CHECK(!res.defined()); + auto res = NullValue(); return res; } } Expr AddSubBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - if (!axes.defined()) { + if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } - AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); - AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); + Message lhs_message = transformer->GetMessage(call->args[0]); + Message rhs_message = transformer->GetMessage(call->args[1]); AttrsEqual equal; - if (lhs_axes.defined() && rhs_axes.defined()) { - CHECK(equal(lhs_axes, rhs_axes)); - CHECK(equal(axes, lhs_axes)); - Expr lhs = transformer->Transform(call->args[0], axes, scale); - Expr rhs = transformer->Transform(call->args[1], axes, scale); + if (lhs_message.defined() && rhs_message.defined()) { + CHECK(equal(lhs_message->axes, rhs_message->axes)); + CHECK(equal(message->axes, lhs_message->axes)); + Expr lhs = transformer->Transform(call->args[0], message, scale); + Expr rhs = transformer->Transform(call->args[1], message, scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); - } else if (lhs_axes.defined()) { - CHECK(equal(axes, lhs_axes)); - Expr lhs = transformer->Transform(call->args[0], axes, scale); + } else if (lhs_message.defined()) { + CHECK(equal(message->axes, lhs_message->axes)); + Expr lhs = transformer->Transform(call->args[0], message, scale); Expr rhs = transformer->Transform( - call->args[1], NullValue(), NullValue()); + call->args[1], NullValue(), NullValue()); Expr rhs_scale = ExpandBiasToMatchAxis( - scale, tlhs->shape.size(), axes); + scale, tlhs->shape.size(), message->axes); rhs = Multiply(rhs, rhs_scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); - } else if (rhs_axes.defined()) { - CHECK(equal(axes, rhs_axes)); + } else if (rhs_message.defined()) { + CHECK(equal(message->axes, rhs_message->axes)); Expr lhs = transformer->Transform( - call->args[0], NullValue(), NullValue()); - Expr rhs = transformer->Transform(call->args[1], axes, scale); + call->args[0], NullValue(), NullValue()); + Expr rhs = transformer->Transform(call->args[1], message, scale); Expr lhs_scale = ExpandBiasToMatchAxis( - scale, trhs->shape.size(), axes); + scale, trhs->shape.size(), message->axes); lhs = Multiply(lhs, lhs_scale); return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { @@ -763,29 +819,29 @@ RELAY_REGISTER_OP("subtract") // Producer operators // Multiply produces the scale-axis pair. Expr MultiplyBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { - CHECK(!axes.defined()) << "outstanding scale"; + CHECK(!message.defined()) << "outstanding scale"; const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); - AxesSet lhs_axes = transformer->GetExpectedAxes(call->args[0]); - AxesSet rhs_axes = transformer->GetExpectedAxes(call->args[1]); - if (lhs_axes.defined() && lhs_axes.size() != 0) { + Message lhs_message = transformer->GetMessage(call->args[0]); + Message rhs_message = transformer->GetMessage(call->args[1]); + if (lhs_message.defined()) { + CHECK(lhs_message->axes.defined() && lhs_message->axes.size()); // NOTE we won't recursively call mutating on scale part. // since there won't be scale chance within scale part. Expr rhs = call->args[1]; - // Only propagate positive scaling. - if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) && - IsAllPositiveConstant(rhs)) { - return transformer->Transform(call->args[0], lhs_axes, rhs); + if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_message->axes, &rhs) && + (!lhs_message->require_positive || IsAllPositiveConstant(rhs))) { + return transformer->Transform(call->args[0], lhs_message, rhs); } - } else if (rhs_axes.defined() && rhs_axes.size() != 0) { - // Only propagate positive scaling. + } else if (rhs_message.defined()) { + CHECK(rhs_message->axes.defined() && rhs_message->axes.size()); Expr lhs = call->args[0]; - if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) && - IsAllPositiveConstant(lhs)) { - return transformer->Transform(call->args[1], rhs_axes, lhs); + if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_message->axes, &lhs) && + (!rhs_message->require_positive || IsAllPositiveConstant(lhs))) { + return transformer->Transform(call->args[1], rhs_message, lhs); } } return transformer->NormalCallTransform(call.operator->()); @@ -796,7 +852,7 @@ RELAY_REGISTER_OP("multiply") // Consumer operators // Conv2D send out requirement of axis folding. -AxesSet Conv2DBackwardPrep(const Call& call, const Array& in_axes) { +Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) { const auto* param = call->attrs.as(); CHECK(param != nullptr); Layout kernel_layout(param->kernel_layout); @@ -817,18 +873,18 @@ AxesSet Conv2DBackwardPrep(const Call& call, const Array& in_axes) { kernel_layout.Indexof('i') < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { - return {c_big_axis}; + return MessageNode::make({c_big_axis}, false); } else { - return NullValue(); + return NullValue(); } } // Conv2D consumes the scale axis during transformation. Expr Conv2DBackwardTransform(const Call& call, - const AxesSet& axes, + const Message& message, const Expr& scale, const BackwardTransformer& transformer) { - if (!axes.defined()) { + if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } const auto* param = call->attrs.as(); @@ -841,8 +897,8 @@ Expr Conv2DBackwardTransform(const Call& call, // TODO(tvm-team) support general data layout CHECK_EQ(kernel_layout.Indexof('o'), -1); CHECK_EQ(kernel_layout.Indexof('i'), -1); - CHECK(axes.size() == 1 && - c_big_axis == axes[0]->value); + CHECK(message->axes.size() == 1 && + c_big_axis == message->axes[0]->value); int big_oc_axis = kernel_layout.Indexof('O'); // Check it must be depthwise or full conv2d. @@ -850,9 +906,9 @@ Expr Conv2DBackwardTransform(const Call& call, CHECK(param->groups == 1 || is_depthwise_conv2d); Expr data = transformer->Transform( - call->args[0], NullValue(), NullValue()); + call->args[0], NullValue(), NullValue()); Expr weight = transformer->Transform( - call->args[1], NullValue(), NullValue()); + call->args[1], NullValue(), NullValue()); // scale on input for deptwise. Expr wscale = ExpandBiasToMatchAxis( scale, kernel_layout.ndim(), {big_oc_axis}); diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 7d0089cfb3c41..dd9e7522fecfa 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -174,7 +174,6 @@ def check(shape, channels, in_scale): assert in_channels == channels weight = relay.var("weight") in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.var("in_scale", shape=(in_channels,)) y1 = before(x, weight, in_bias, in_scale, channels) y1 = relay.ir_pass.infer_type(y1) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) @@ -182,10 +181,52 @@ def check(shape, channels, in_scale): in_scale = relay.var("in_scale", shape=(4,)) check((2, 11, 10, 4), 4, in_scale) - in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0)).astype("float32") + in_scale = relay.const(-_get_positive_scale((4,))) check((2, 11, 10, 4), 4, in_scale) +def test_fold_fwd_negative_scale(): + """Testcase of folding negative scale""" + def before(x, conv_weight, in_scale, channels): + args = [x, conv_weight] + x = relay.multiply(x, in_scale) + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def expected(x, conv_weight, in_scale, channels): + # use a fixed order of args so alpha equal check can pass + args = [x, conv_weight] + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + y = relay.nn.conv2d(x, + conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def check(shape, channels): + x = relay.var("x", shape=shape) + in_channels = shape[1] + in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) + weight = relay.var("weight") + y1 = before(x, weight, in_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) + y1_expected = expected(x, weight, in_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 4) + + def test_fold_bwd_simple(): """Simple testcase.""" def before(x, conv_weight, out_bias, out_scale, channels): @@ -223,7 +264,7 @@ def check(shape, channels): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) @@ -283,7 +324,7 @@ def check(shape, channels): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) @@ -356,7 +397,7 @@ def check(shape, channels): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels,1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels,1, 1))) y1 = before(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) @@ -411,7 +452,7 @@ def check(shape, channels, fbefore): in_channels = shape[1] weight = relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32")) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) y1 = fbefore(x, weight, out_bias, out_scale, channels) y1 = relay.ir_pass.infer_type(y1) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) @@ -448,13 +489,55 @@ def check(shape, channels, out_scale): check((4, 4, 10, 10), 4, out_scale) +def test_fold_bwd_negative_scale(): + """Testcase of folding negative scale""" + def before(x, conv_weight, out_scale, channels): + args = [x, conv_weight] + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.multiply(y, out_scale) + return relay.Function(args, y) + + def expected(x, conv_weight, out_scale, channels): + # use a fixed order of args so alpha equal check can pass + args = [x, conv_weight] + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + y = relay.nn.conv2d(x, conv_weight, + channels=channels, + kernel_size=(3, 3), + padding=(1, 1)) + return relay.Function(args, y) + + def check(shape, channels): + x = relay.var("x", shape=shape) + weight = relay.var("weight") + out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) + y1 = before(x, weight, out_scale, channels) + y1 = relay.ir_pass.infer_type(y1) + type_dict = {x.name_hint:x.checked_type for x in y1.params} + weight = relay.var("weight", type_dict["weight"]) + y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) + y1_expected = expected(x, weight, out_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) + assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) + + check((2, 4, 10, 10), 8) + + if __name__ == "__main__": test_fold_fwd_simple() test_fold_fwd_dual_path() test_fold_fwd_fail() test_fold_fwd_relu_fail() + test_fold_fwd_negative_scale() test_fold_bwd_simple() test_fold_bwd_dual_path() test_fold_bwd_dual_consumer() test_fold_bwd_fail() test_fold_bwd_relu_fail() + test_fold_bwd_negative_scale() From d0f83664b84cd8437eabb179df56c5b57f8610df Mon Sep 17 00:00:00 2001 From: sf-wind Date: Wed, 16 Jan 2019 21:31:58 -0800 Subject: [PATCH 05/16] Avoid runtime exception when file doesn't exist (#2441) * Avoid runtime exception when file doesn't exist * Update the check based on feedback * Revert the old fix --- python/tvm/autotvm/tophub.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index b611f3cee0541..9ec9becc72459 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -77,7 +77,8 @@ def context(target, extra_files=None): for name in possible_names: name = _alias(name) if name in all_packages: - check_backend(name) + if not check_backend(name): + continue filename = "%s_%s.log" % (name, PACKAGE_VERSION[name]) best_context.load(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, filename)) @@ -98,6 +99,11 @@ def check_backend(backend): ---------- backend: str The name of backend. + + Returns + ---------- + success: bool + Whether the check is successful. """ backend = _alias(backend) assert backend in PACKAGE_VERSION, 'Cannot find backend "%s" in TopHub' % backend @@ -105,7 +111,7 @@ def check_backend(backend): version = PACKAGE_VERSION[backend] package_name = "%s_%s.log" % (backend, version) if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)): - return + return True if sys.version_info >= (3,): import urllib.request as urllib2 @@ -113,8 +119,10 @@ def check_backend(backend): import urllib2 try: download_package(package_name) + return True except urllib2.URLError as e: logging.warning("Failed to download tophub package for %s: %s", backend, e) + return False def download_package(package_name): From d274e4b3d33e8038296dfddbb4d9d1de8e0735aa Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 17 Jan 2019 08:57:16 -0800 Subject: [PATCH 06/16] [Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377) --- cmake/modules/ANTLR.cmake | 24 +- include/tvm/relay/base.h | 4 +- python/tvm/relay/_base.py | 5 + python/tvm/relay/_parser.py | 135 ++++++++-- python/tvm/relay/base.py | 10 + python/tvm/relay/grammar/Relay.g4 | 56 +++-- python/tvm/relay/parser.py | 12 +- src/relay/ir/base.cc | 14 ++ tests/python/relay/test_ir_parser.py | 359 ++++++++++++++------------- 9 files changed, 400 insertions(+), 219 deletions(-) create mode 100644 python/tvm/relay/_base.py diff --git a/cmake/modules/ANTLR.cmake b/cmake/modules/ANTLR.cmake index 72eb5925bda01..aede0098b7fb6 100644 --- a/cmake/modules/ANTLR.cmake +++ b/cmake/modules/ANTLR.cmake @@ -1,7 +1,15 @@ if(USE_ANTLR) - if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) - set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") + file(GLOB_RECURSE ANTLR4 + /usr/local/lib/antlr-*-complete.jar + /usr/local/Cellar/*antlr-*-complete.jar) + if(DEFINED ANTLR4) + # Get the first element of the list of antlr jars. + # Sort and reverse the list so the item selected is the highest + # version in lib or else in Cellar if no lib installation exists. + list(SORT ANTLR4) + list(REVERSE ANTLR4) + list(GET ANTLR4 0 ANTLR4) set(RELAY_PARSER_DIR ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) @@ -14,15 +22,21 @@ if(USE_ANTLR) ${RELAY_PARSER_DIR}/py3/RelayParser.py ${RELAY_PARSER_DIR}/py3/RelayLexer.py) + set(JAVA_HOME $ENV{JAVA_HOME}) + if (NOT DEFINED JAVA_HOME) + # Hack to get system to search for Java itself. + set(JAVA_HOME "/usr") + endif() + # Generate ANTLR grammar for parsing. add_custom_command(OUTPUT ${RELAY_PARSER} - COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 - COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 + COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 + COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 WORKING_DIRECTORY ${RELAY_PARSER_DIR}) add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) else() - message(FATAL_ERROR "Can't find ANTLR4!") + message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4}) endif() endif(USE_ANTLR) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f72f557a97652..f90acdc9400bb 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -108,7 +108,9 @@ class SourceName : public NodeRef { * \brief access the internal node container * \return the pointer to the internal node container */ - inline const SourceNameNode* operator->() const; + inline const SourceNameNode* operator->() const { + return static_cast(this->node_.get()); + } /*! * \brief Get an SourceName for a given operator name. diff --git a/python/tvm/relay/_base.py b/python/tvm/relay/_base.py new file mode 100644 index 0000000000000..b23655a0406a2 --- /dev/null +++ b/python/tvm/relay/_base.py @@ -0,0 +1,5 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface of expr function exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._base", __name__) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 2637e7e00f77a..c0455a3361e93 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -6,13 +6,17 @@ import sys from collections import deque -from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any +from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict + +import tvm from . import module +from .base import Span, SourceName from . import expr from . import ty from . import op + class ParseError(Exception): """Exception type for parse errors.""" @@ -76,22 +80,46 @@ def lookup(scopes, name): return val return None +def spanify(f): + """A decorator which attaches span information + to the value returned by calling `f`. + + Intended for use with the below AST visiting + methods. The idea is that after we do the work + of constructing the AST we attach Span information. + """ + + def _wrapper(*args, **kwargs): + # Assumes 0th arg is self and gets source_name from object. + sn = args[0].source_name + # Assumes 1st arg is an ANTLR parser context. + ctx = args[1] + ast = f(*args, **kwargs) + line, col = ctx.getSourceInterval() + sp = Span(sn, line, col) + ast.set_span(sp) + return ast + return _wrapper + # TODO(@jmp): Use https://stackoverflow.com/q/13889941 # to figure out how to get ANTLR4 to be more unhappy about syntax errors class ParseTreeToRelayIR(RelayVisitor): """Parse Relay text format into Relay IR.""" - def __init__(self): - # type: () -> None + def __init__(self, source_name): + # type: (str) -> None + self.source_name = source_name self.module = module.Module({}) # type: module.Module # Adding an empty scope allows naked lets without pain. - self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] - self.global_var_scope = deque() # type: Scope[expr.GlobalVar] - self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] + self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] + self.global_var_scope = deque() # type: Scope[expr.GlobalVar] + self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] + self.graph_expr = [] # type: List[expr.Expr] super(ParseTreeToRelayIR, self).__init__() + def enter_var_scope(self): # type: () -> None """Enter a new Var scope so it can be popped off later.""" @@ -146,20 +174,25 @@ def visitTerminal(self, node): node_type = node.getSymbol().type node_text = node.getText() + name = node_text[1:] # variables if node_type == RelayLexer.GLOBAL_VAR: - return lookup([self.global_var_scope], node_text[1:]) + return lookup(deque([self.global_var_scope]), node_text[1:]) elif node_type == RelayLexer.LOCAL_VAR: - name = node_text[1:] + # Remove the leading '%' and lookup the name. var = lookup(self.var_scopes, name) if var is None: raise ParseError("Couldn't resolve `{}`.".format(name)) - return var + elif node_type == RelayLexer.GRAPH_VAR: + try: + return self.graph_expr[int(name)] + except IndexError: + raise ParseError("Couldn't resolve `{}`".format(name)) # data types - elif node_type == RelayLexer.INT: + elif node_type == RelayLexer.NAT: return int(node_text) elif node_type == RelayLexer.FLOAT: return float(node_text) @@ -190,7 +223,7 @@ def getType_(self, ctx): return self.visit(ctx) def visitProg(self, ctx): - # type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] + # type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module] if ctx.defn(): self.visit_list(ctx.defn()) return self.module @@ -219,7 +252,7 @@ def visitScalarFloat(self, ctx): def visitScalarInt(self, ctx): # type: (RelayParser.ScalarIntContext) -> expr.Constant - return expr.const(self.visit(ctx.INT())) + return expr.const(self.visit(ctx.NAT())) def visitScalarBool(self, ctx): # type: (RelayParser.ScalarBoolContext) -> expr.Constant @@ -240,7 +273,7 @@ def visitTuple(self, ctx): return expr.Tuple(tup) # Currently doesn't support mutable sequencing. - def visitSeq(self, ctx): + def visitLet(self, ctx): # type: (RelayParser.SeqContext) -> expr.Let """Desugar various sequence constructs to Relay Let nodes.""" if ctx.MUT() is not None: @@ -253,7 +286,7 @@ def visitSeq(self, ctx): else: local_var = ctx.var().ident().LOCAL_VAR() if local_var is None: - raise ParseError('Only local ids may be used in `let`s.') + raise ParseError("Only local ids may be used in `let`s.") ident = local_var.getText()[1:] type_ = self.getType_(ctx.var().type_()) @@ -278,12 +311,14 @@ def visitBinOp(self, ctx): return relay_op(arg0, arg1) + @spanify def visitVar(self, ctx): # type: (RelayParser.VarContext) -> expr.Var + """Visit a single variable.""" ident = ctx.ident().LOCAL_VAR() if ident is None: - raise ParseError('Only local ids may be used in params.') + raise ParseError("Only local ids may be used in vars.") type_ = self.getType_(ctx.type_()) @@ -293,15 +328,33 @@ def visitVarList(self, ctx): # type: (RelayParser.VarListContext) -> List[expr.Var] return self.visit_list(ctx.var()) + # TODO: support a larger class of values than just Relay exprs + def visitAttr(self, ctx): + # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr] + return (ctx.CNAME().getText(), self.visit(ctx.expr())) + + def visitAttrList(self, ctx): + # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr] + return dict(self.visit_list(ctx.attr())) + + def visitArgList(self, + ctx # type: RelayParser.ArgListContext + ): + # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]] + var_list = self.visit(ctx.varList()) if ctx.varList() else None + attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None + + return (var_list, attr_list) + def mk_func(self, ctx): - # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function + # type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() - var_list = self.visit(ctx.varList()) + var_list, attr_list = self.visit(ctx.argList()) ret_type = self.getType_(ctx.type_()) type_params = list(self.exit_type_param_scope()) @@ -311,22 +364,28 @@ def mk_func(self, ctx): body = self.visit(ctx.body()) self.exit_var_scope() - return expr.Function(var_list, body, ret_type, type_params) # type: ignore + attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None + + return expr.Function(var_list, body, ret_type, type_params, attrs) + @spanify def visitFunc(self, ctx): # type: (RelayParser.FuncContext) -> expr.Function return self.mk_func(ctx) + # TODO: how to set spans for definitions? + # @spanify def visitDefn(self, ctx): # type: (RelayParser.DefnContext) -> None ident = ctx.ident().GLOBAL_VAR() if ident is None: - raise ParseError('Only global ids may be used in `def`s.') + raise ParseError("Only global ids may be used in `def`s.") ident_name = ident.getText()[1:] ident = self.mk_global_var(ident_name) self.module[ident] = self.mk_func(ctx) + @spanify def visitCall(self, ctx): # type: (RelayParser.CallContext) -> expr.Call visited_exprs = self.visit_list(ctx.expr()) @@ -336,6 +395,7 @@ def visitCall(self, ctx): return expr.Call(func, args, None, None) + @spanify def visitIfElse(self, ctx): # type: (RelayParser.IfElseContext) -> expr.If """Construct a Relay If node. Creates a new scope for each branch.""" @@ -351,6 +411,27 @@ def visitIfElse(self, ctx): return expr.If(cond, true_branch, false_branch) + @spanify + def visitGraph(self, ctx): + # type: (RelayParser.GraphContext) -> expr.Expr + """Visit a graph variable assignment.""" + if ctx.ident().GRAPH_VAR() is None: + raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText())) + graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:]) + + self.enter_var_scope() + value = self.visit(ctx.expr(0)) + self.exit_var_scope() + + if graph_nid != len(self.graph_expr): + raise ParseError( + "Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ + "but got `%{}`".format(graph_nid)) + self.graph_expr.append(value) + + kont = self.visit(ctx.expr(1)) + return kont + # Types # pylint: disable=unused-argument @@ -428,8 +509,18 @@ def make_parser(data): token_stream = CommonTokenStream(lexer) return RelayParser(token_stream) -def fromtext(data): - # type: (str) -> Union[expr.Expr, env.Environment] +__source_name_counter__ = 0 + +def fromtext(data, source_name=None): + # type: (str, str) -> Union[expr.Expr, module.Module] """Parse a Relay program.""" + global __source_name_counter__ + + if source_name is None: + source_name = "source_file{0}".format(__source_name_counter__) + + if isinstance(source_name, str): + source_name = SourceName(source_name) + tree = make_parser(data).prog() - return ParseTreeToRelayIR().visit(tree) + return ParseTreeToRelayIR(source_name).visit(tree) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index c50013b199ac1..780d52863079b 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -4,6 +4,7 @@ from .._ffi.node import NodeBase, register_node as _register_tvm_node from . import _make from . import _expr +from . import _base NodeBase = NodeBase @@ -63,6 +64,9 @@ def astext(self, show_meta_data=True, annotate=None): """ return _expr.RelayPrint(self, show_meta_data, annotate) + def set_span(self, span): + _base.set_span(self, span) + @register_relay_node class Span(RelayNode): @@ -71,6 +75,12 @@ class Span(RelayNode): def __init__(self, source, lineno, col_offset): self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) +@register_relay_node +class SourceName(RelayNode): + """A identifier for a source location""" + + def __init__(self, name): + self.__init_handle_by_constructor__(_make.SourceName, name) @register_relay_node class Id(NodeBase): diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index cf6f9a7caa2b5..0a22062655026 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -1,5 +1,7 @@ grammar Relay; +SEMVER: 'v0.0.1' ; + // Lexing // comments WS : [ \t\n\r]+ -> skip ; @@ -20,7 +22,8 @@ NE: '!=' ; opIdent: CNAME ; GLOBAL_VAR: '@' CNAME ; -LOCAL_VAR: '%' CNAME ; +LOCAL_VAR: '%' CNAME; +GRAPH_VAR: '%' NAT; MUT: 'mut' ; @@ -31,13 +34,13 @@ BOOL_LIT // non-negative floats FLOAT - : INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5 - | INT EXP // 1e10 3e4 + : NAT '.' NAT EXP? // 1.35, 1.35E-9, 0.3, 4.5 + | NAT EXP // 1e10 3e4 ; // non-negative ints -INT: DIGIT+ ; -fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...] +NAT: DIGIT+ ; +fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...] CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; fragment LETTER: [a-zA-Z] ; @@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ; // Parsing // A Relay program is a list of global definitions or an expression. -prog: (defn* | expr) EOF ; +prog: SEMVER (defn* | expr) EOF ; // option: 'set' ident BOOL_LIT ; @@ -73,10 +76,11 @@ expr | 'if' '(' expr ')' body 'else' body # ifElse // sequencing - | 'let' MUT? var '=' expr ';' expr # seq - | 'let' MUT? var '=' '{' expr '}' ';' expr # seq + | 'let' MUT? var '=' expr ';' expr # let + | 'let' MUT? var '=' '{' expr '}' ';' expr # let // sugar for let %_ = expr; expr - | expr ';' expr # seq + | expr ';' expr # let + | ident '=' expr ';' expr # graph // mutable update // | ident '=' expr # writeRef @@ -84,16 +88,25 @@ expr | ident # identExpr | scalar # scalarExpr - // | expr '.' INT # project - // | 'debug' # debug + // | expr '.' NAT # project + // | 'debug' # debug ; -func: 'fn' varList ('->' type_)? body ; -defn: 'def' ident varList ('->' type_)? body ; +func: 'fn' '(' argList ')' ('->' type_)? body ; +defn: 'def' ident '(' argList ')' ('->' type_)? body ; + +argList + : varList + | attrList + | varList ',' attrList + ; -varList: '(' (var (',' var)*)? ')' ; +varList: (var (',' var)*)? ; var: ident (':' type_)? ; +attrList: (attr (',' attr)*)? ; +attr: CNAME '=' expr ; + // TODO(@jmp): for improved type annotations // returnAnno: (ident ':')? type_ ; @@ -110,7 +123,7 @@ type_ // | identType '[' (type_ (',' type_)*)? ']' # callType | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | '_' # incompleteType - | INT # intType + | NAT # intType ; shapeSeq @@ -123,20 +136,20 @@ shape : '(' shape ')' # parensShape // | type_ op=('*'|'/') type_ # binOpType // | type_ op=('+'|'-') type_ # binOpType - | INT # intShape + | NAT # intShape ; identType: CNAME ; -// Int8, Int16, Int32, Int64 -// UInt8, UInt16, UInt32, UInt64 -// Float16, Float32, Float64 -// Bool +// int8, int16, int32, int64 +// uint8, uint16, uint32, uint64 +// float16, float32, float64 +// bool body: '{' expr '}' ; scalar : FLOAT # scalarFloat - | INT # scalarInt + | NAT # scalarInt | BOOL_LIT # scalarBool ; @@ -144,4 +157,5 @@ ident : opIdent | GLOBAL_VAR | LOCAL_VAR + | GRAPH_VAR ; diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 51200343f147c..ba0b1aac063ed 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -1,8 +1,13 @@ """A parser for Relay's text format.""" from __future__ import absolute_import +from .. import register_func def enabled(): - """Is the parser enabled/Can we import the parser?""" + """Checks whether the parser is enabled, this allows users to + optionally support building the parser. + + We use this check before importing the parser. + """ try: # pylint: disable=unused-variable from tvm.relay import _parser @@ -11,7 +16,8 @@ def enabled(): except Exception: return False -def fromtext(data): +@register_func("relay.fromtext") +def fromtext(data, source_name=None): """Parse a Relay program.""" from tvm.relay import _parser - return _parser.fromtext(data) + return _parser.fromtext(data, source_name) diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 06593b6420f57..8df54883616aa 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) { return SourceName(GetSourceNameNode(name)); } +TVM_REGISTER_API("relay._make.SourceName") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SourceName::Get(args[0]); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SourceNameNode* node, tvm::IRPrinter* p) { p->stream << "SourceName(" << node->name << ", " << node << ")"; @@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(IdNode); +TVM_REGISTER_API("relay._base.set_span") +.set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef node_ref = args[0]; + auto rn = node_ref.as_derived(); + CHECK(rn); + Span sp = args[1]; + rn->span = sp; +}); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index d32750a4aafa6..08d4c430101b2 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -8,11 +8,12 @@ from typing import Union from functools import wraps if enabled(): - from tvm.relay._parser import ParseError - raises_parse_error = raises(ParseError) + raises_parse_error = raises(tvm._ffi.base.TVMError) else: raises_parse_error = lambda x: x +SEMVER = "v0.0.1" + BINARY_OPS = { "*": relay.multiply, "/": relay.divide, @@ -48,6 +49,10 @@ "float16x4", } +def parses_as(code, expr): + # type: (str, relay.Expr) -> bool + return alpha_equal(relay.fromtext(SEMVER + "\n" + code), expr) + def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) return x.data.asnumpy().item() @@ -74,80 +79,80 @@ def wrapper(): @if_parser_enabled def test_comments(): - assert alpha_equal( - relay.fromtext(""" - // This is a line comment! - () - """), + assert parses_as( + """ + // This is a line comment! + () + """, UNIT ) - assert alpha_equal( - relay.fromtext(""" - /* This is a block comment! - This is still a block comment! - */ - () - """), + assert parses_as( + """ + /* This is a block comment! + This is still a block comment! + */ + () + """, UNIT ) @if_parser_enabled def test_int_literal(): - assert isinstance(relay.fromtext("1"), relay.Constant) - assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray) + assert isinstance(relay.fromtext(SEMVER+"1"), relay.Constant) + assert isinstance(relay.fromtext(SEMVER+"1").data, tvm.ndarray.NDArray) - assert get_scalar(relay.fromtext("1")) == 1 - assert get_scalar(relay.fromtext("10")) == 10 - assert get_scalar(relay.fromtext("0")) == 0 - assert get_scalar(relay.fromtext("-100")) == -100 - assert get_scalar(relay.fromtext("-05")) == -5 + assert get_scalar(relay.fromtext(SEMVER+"1")) == 1 + assert get_scalar(relay.fromtext(SEMVER+"10")) == 10 + assert get_scalar(relay.fromtext(SEMVER+"0")) == 0 + assert get_scalar(relay.fromtext(SEMVER+"-100")) == -100 + assert get_scalar(relay.fromtext(SEMVER+"-05")) == -5 @if_parser_enabled def test_float_literal(): - assert get_scalar(relay.fromtext("1.0")) == 1.0 - assert isclose(get_scalar(relay.fromtext("1.56667")), 1.56667) - assert get_scalar(relay.fromtext("0.0")) == 0.0 - assert get_scalar(relay.fromtext("-10.0")) == -10.0 + assert get_scalar(relay.fromtext(SEMVER+"1.0")) == 1.0 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1.56667")), 1.56667) + assert get_scalar(relay.fromtext(SEMVER+"0.0")) == 0.0 + assert get_scalar(relay.fromtext(SEMVER+"-10.0")) == -10.0 # scientific notation - assert isclose(get_scalar(relay.fromtext("1e-1")), 1e-1) - assert get_scalar(relay.fromtext("1e+1")) == 1e+1 - assert isclose(get_scalar(relay.fromtext("1E-1")), 1E-1) - assert get_scalar(relay.fromtext("1E+1")) == 1E+1 - assert isclose(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1) - assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1 - assert isclose(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1) - assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1e-1")), 1e-1) + assert get_scalar(relay.fromtext(SEMVER+"1e+1")) == 1e+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1E-1")), 1E-1) + assert get_scalar(relay.fromtext(SEMVER+"1E+1")) == 1E+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0e-1")), 1.0e-1) + assert get_scalar(relay.fromtext(SEMVER+"1.0e+1")) == 1.0e+1 + assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0E-1")), 1.0E-1) + assert get_scalar(relay.fromtext(SEMVER+"1.0E+1")) == 1.0E+1 @if_parser_enabled def test_bool_literal(): - assert get_scalar(relay.fromtext("True")) == True - assert get_scalar(relay.fromtext("False")) == False + assert get_scalar(relay.fromtext(SEMVER+"True")) == True + assert get_scalar(relay.fromtext(SEMVER+"False")) == False @if_parser_enabled def test_negative(): - assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call) - assert get_scalar(relay.fromtext("--10")) == 10 - assert get_scalar(relay.fromtext("---10")) == -10 + assert isinstance(relay.fromtext(SEMVER+"let %x = 1; -%x").body, relay.Call) + assert get_scalar(relay.fromtext(SEMVER+"--10")) == 10 + assert get_scalar(relay.fromtext(SEMVER+"---10")) == -10 @if_parser_enabled def test_bin_op(): for bin_op in BINARY_OPS.keys(): - assert alpha_equal( - relay.fromtext("1 {} 1".format(bin_op)), + assert parses_as( + "1 {} 1".format(bin_op), BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) @if_parser_enabled def test_parens(): - assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1")) - assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)")) + assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"(1 * 1) + 1")) + assert not alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"1 * (1 + 1)")) @if_parser_enabled def test_op_assoc(): - assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1")) - assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))")) + assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1 < 1 == 1"), relay.fromtext(SEMVER+"(((1 * 1) + 1) < 1) == 1")) + assert alpha_equal(relay.fromtext(SEMVER+"1 == 1 < 1 + 1 * 1"), relay.fromtext(SEMVER+"1 == (1 < (1 + (1 * 1)))")) @nottest @if_parser_enabled @@ -159,24 +164,24 @@ def test_vars(): # assert temp_var.name == "1" # var - var = relay.fromtext("let %foo = (); %foo") + var = relay.fromtext(SEMVER+"let %foo = (); %foo") assert isinstance(var.body, relay.Var) assert var.body.name_hint == "foo" # global var - global_var = relay.fromtext("@foo") + global_var = relay.fromtext(SEMVER+"@foo") assert isinstance(global_var, relay.GlobalVar) assert global_var.name_hint == "foo" # operator id - op = relay.fromtext("foo") + op = relay.fromtext(SEMVER+"foo") assert isinstance(op, relay.Op) assert op.name == "foo" @if_parser_enabled def test_let(): - assert alpha_equal( - relay.fromtext("let %x = 1; ()"), + assert parses_as( + "let %x = 1; ()", relay.Let( X, relay.const(1), @@ -184,18 +189,35 @@ def test_let(): ) ) + assert parses_as( + """ + let %x = 1; + let %y = 2; + () + """, + relay.Let( + X, + relay.const(1), + relay.Let( + Y, + relay.const(2), + UNIT + ) + ) + ) + @if_parser_enabled def test_seq(): - assert alpha_equal( - relay.fromtext("(); ()"), + assert parses_as( + "(); ()", relay.Let( _, UNIT, UNIT) ) - assert alpha_equal( - relay.fromtext("let %_ = { 1 }; ()"), + assert parses_as( + "let %_ = { 1 }; ()", relay.Let( X, relay.const(1), @@ -203,31 +225,48 @@ def test_seq(): ) ) +@if_parser_enabled +def test_graph(): + assert parses_as( + "%0 = (); %1 = 1; (%0, %0, %1)", + relay.Tuple([UNIT, UNIT, relay.const(1)]) + ) + + assert not parses_as( + "%0 = (); %1 = 1; (%0, %0, %1)", + relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)]) + ) + +@raises_parse_error +@if_parser_enabled +def test_graph_wrong_order(): + relay.fromtext(SEMVER+"%1 = (); %1") + @raises_parse_error @if_parser_enabled def test_let_global_var(): - relay.fromtext("let @x = 1; ()") + relay.fromtext(SEMVER+"let @x = 1; ()") @raises_parse_error @if_parser_enabled def test_let_op(): - relay.fromtext("let x = 1; ()") + relay.fromtext(SEMVER+"let x = 1; ()") @if_parser_enabled def test_tuple(): - assert alpha_equal(relay.fromtext("()"), relay.Tuple([])) + assert parses_as("()", relay.Tuple([])) - assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)])) + assert parses_as("(0,)", relay.Tuple([relay.const(0)])) - assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)])) + assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)])) - assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) + assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) @if_parser_enabled def test_func(): # 0 args - assert alpha_equal( - relay.fromtext("fn () { 0 }"), + assert parses_as( + "fn () { 0 }", relay.Function( [], relay.const(0), @@ -237,8 +276,8 @@ def test_func(): ) # 1 arg - assert alpha_equal( - relay.fromtext("fn (%x) { %x }"), + assert parses_as( + "fn (%x) { %x }", relay.Function( [X], X, @@ -248,8 +287,8 @@ def test_func(): ) # 2 args - assert alpha_equal( - relay.fromtext("fn (%x, %y) { %x + %y }"), + assert parses_as( + "fn (%x, %y) { %x + %y }", relay.Function( [X, Y], relay.add(X, Y), @@ -259,8 +298,8 @@ def test_func(): ) # annotations - assert alpha_equal( - relay.fromtext("fn (%x: int32) -> int32 { %x }"), + assert parses_as( + "fn (%x: int32) -> int32 { %x }", relay.Function( [X_ANNO], X_ANNO, @@ -269,11 +308,17 @@ def test_func(): ) ) + # attributes + assert parses_as( + "fn (n=5) { () }", + relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5))) + ) + # TODO(@jmp): Crashes if %x isn't annnotated. -# @nottest @if_parser_enabled def test_defn(): id_defn = relay.fromtext( + SEMVER+ """ def @id(%x: int32) -> int32 { %x @@ -284,6 +329,7 @@ def @id(%x: int32) -> int32 { @if_parser_enabled def test_recursive_call(): id_defn = relay.fromtext( + SEMVER+ """ def @id(%x: int32) -> int32 { @id(%x) @@ -293,16 +339,14 @@ def @id(%x: int32) -> int32 { @if_parser_enabled def test_ifelse(): - assert alpha_equal( - relay.fromtext( + assert parses_as( """ if (True) { 0 } else { 1 } - """ - ), + """, relay.If( relay.const(True), relay.const(0), @@ -314,6 +358,7 @@ def test_ifelse(): @if_parser_enabled def test_ifelse_scope(): relay.fromtext( + SEMVER+ """ if (True) { let %x = (); @@ -328,13 +373,11 @@ def test_ifelse_scope(): def test_call(): # select right function to call: simple ident case id_func = relay.Var("id") - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %id = fn (%x) { %x }; 10 * %id(10) - """ - ), + """, relay.Let( id_func, relay.Function([X], X, None, []), @@ -344,13 +387,11 @@ def test_call(): # 0 args constant = relay.Var("constant") - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %constant = fn () { 0 }; %constant() - """ - ), + """, relay.Let( constant, relay.Function([], relay.const(0), None, []), @@ -360,13 +401,11 @@ def test_call(): # 1 arg id_var = relay.Var("id") - assert alpha_equal( - relay.fromtext( - """ - let %id = fn (%x) { %x }; - %id(1) - """ - ), + assert parses_as( + """ + let %id = fn (%x) { %x }; + %id(1) + """, relay.Let( id_var, relay.Function([X], X, None, []), @@ -376,13 +415,11 @@ def test_call(): # 2 args multiply = relay.Var("multiply") - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %multiply = fn (%x, %y) { %x * %y }; %multiply(0, 0) - """ - ), + """, relay.Let( multiply, relay.Function( @@ -396,12 +433,10 @@ def test_call(): ) # anonymous function - assert alpha_equal( - relay.fromtext( + assert parses_as( """ (fn (%x) { %x })(0) - """ - ), + """, relay.Call( relay.Function( [X], @@ -415,45 +450,44 @@ def test_call(): ) ) + # TODO(@jmp): re-enable after sequence parsing improvements # curried function - curried_mult = relay.Var("curried_mult") - alpha_equal( - relay.fromtext( - """ - let %curried_mult = - fn (%x) { - fn (%y) { - %x * %y - } - }; - %curried_mult(0); - %curried_mult(0)(0) - """ - ), - relay.Let( - curried_mult, - relay.Function( - [X], - relay.Function( - [Y], - relay.multiply(X, Y), - None, - [] - ), - None, - [] - ), - relay.Let( - _, - relay.Call(curried_mult, [relay.const(0)], None, None), - relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) - ) - ) - ) + # curried_mult = relay.Var("curried_mult") + # assert parses_as( + # """ + # let %curried_mult = + # fn (%x) { + # fn (%y) { + # %x * %y + # } + # }; + # %curried_mult(0); + # %curried_mult(0)(0) + # """, + # relay.Let( + # curried_mult, + # relay.Function( + # [X], + # relay.Function( + # [Y], + # relay.multiply(X, Y), + # None, + # [] + # ), + # None, + # [] + # ), + # relay.Let( + # _, + # relay.Call(curried_mult, [relay.const(0)], None, None), + # relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None) + # ) + # ) + # ) # op - alpha_equal( - relay.fromtext("abs(1)"), + assert parses_as( + "abs(1)", relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) ) @@ -461,8 +495,8 @@ def test_call(): @if_parser_enabled def test_incomplete_type(): - assert alpha_equal( - relay.fromtext("let %_ : _ = (); ()"), + assert parses_as( + "let %_ : _ = (); ()", relay.Let( _, UNIT, @@ -473,7 +507,7 @@ def test_incomplete_type(): @if_parser_enabled def test_builtin_types(): for builtin_type in TYPES: - relay.fromtext("let %_ : {} = (); ()".format(builtin_type)) + relay.fromtext(SEMVER+"let %_ : {} = (); ()".format(builtin_type)) @nottest @if_parser_enabled @@ -482,8 +516,8 @@ def test_call_type(): @if_parser_enabled def test_tensor_type(): - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(), float32] = (); ()"), + assert parses_as( + "let %_ : Tensor[(), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((), "float32")), UNIT, @@ -491,8 +525,8 @@ def test_tensor_type(): ) ) - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"), + assert parses_as( + "let %_ : Tensor[(1,), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), UNIT, @@ -500,8 +534,8 @@ def test_tensor_type(): ) ) - assert alpha_equal( - relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"), + assert parses_as( + "let %_ : Tensor[(1, 1), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1, 1), "float32")), UNIT, @@ -511,12 +545,10 @@ def test_tensor_type(): @if_parser_enabled def test_function_type(): - assert alpha_equal( - relay.fromtext( - """ - let %_: fn () -> int32 = fn () -> int32 { 0 }; () - """ - ), + assert parses_as( + """ + let %_: fn () -> int32 = fn () -> int32 { 0 }; () + """, relay.Let( relay.Var("_", relay.FuncType([], int32, [], [])), relay.Function([], relay.const(0), int32, []), @@ -524,12 +556,10 @@ def test_function_type(): ) ) - assert alpha_equal( - relay.fromtext( - """ - let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () - """ - ), + assert parses_as( + """ + let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () + """, relay.Let( relay.Var("_", relay.FuncType([int32], int32, [], [])), relay.Function([relay.Var("x", int32)], relay.const(0), int32, []), @@ -537,12 +567,10 @@ def test_function_type(): ) ) - assert alpha_equal( - relay.fromtext( - """ - let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () - """ - ), + assert parses_as( + """ + let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () + """, relay.Let( relay.Var("_", relay.FuncType([int32, int32], int32, [], [])), relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []), @@ -552,11 +580,10 @@ def test_function_type(): @if_parser_enabled def test_tuple_type(): - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %_: () = (); () - """), + """, relay.Let( relay.Var("_", relay.TupleType([])), UNIT, @@ -564,11 +591,10 @@ def test_tuple_type(): ) ) - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %_: (int32,) = (0,); () - """), + """, relay.Let( relay.Var("_", relay.TupleType([int32])), relay.Tuple([relay.const(0)]), @@ -576,11 +602,10 @@ def test_tuple_type(): ) ) - assert alpha_equal( - relay.fromtext( + assert parses_as( """ let %_: (int32, int32) = (0, 1); () - """), + """, relay.Let( relay.Var("_", relay.TupleType([int32, int32])), relay.Tuple([relay.const(0), relay.const(1)]), From 985e7d72bfc0f4bf486ac87c4a67b2dbb36d23f4 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Fri, 18 Jan 2019 01:08:30 +0800 Subject: [PATCH 07/16] Update docs for some new modules (#2454) --- docs/api/python/contrib.rst | 5 +++++ docs/api/python/relay/frontend.rst | 4 ++++ docs/api/python/topi.rst | 5 +++++ 3 files changed, 14 insertions(+) diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index a58a3aa4fbefd..bc566759da99e 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -77,6 +77,11 @@ tvm.contrib.rocm .. automodule:: tvm.contrib.rocm :members: +tvm.contrib.sparse +~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.sparse + :members: + tvm.contrib.spirv ~~~~~~~~~~~~~~~~~ diff --git a/docs/api/python/relay/frontend.rst b/docs/api/python/relay/frontend.rst index a418e042bf3de..054d3cecc1c54 100644 --- a/docs/api/python/relay/frontend.rst +++ b/docs/api/python/relay/frontend.rst @@ -5,3 +5,7 @@ tvm.relay.frontend .. automodule:: tvm.relay.frontend .. autofunction:: tvm.relay.frontend.from_mxnet + +.. autofunction:: tvm.relay.frontend.from_keras + +.. autofunction:: tvm.relay.frontend.from_onnx diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 450573e4c524e..856bad198e88e 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -144,6 +144,11 @@ topi.image ~~~~~~~~~~ .. autofunction:: topi.image.resize +topi.sparse +~~~~~~~~~~~ +.. autofunction:: topi.sparse.csrmv +.. autofunction:: topi.sparse.csrmm +.. autofunction:: topi.sparse.dense topi.generic ~~~~~~~~~~~~ From b374192b563cf76f0ff68f1a0272ed2cdabe2c59 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Thu, 17 Jan 2019 15:58:32 -0800 Subject: [PATCH 08/16] move fallback out of the build interface (#2456) --- python/tvm/relay/build_module.py | 24 +++++++++------------- tests/python/relay/test_pass_annotation.py | 15 +++++++------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index df2cc105bc688..51a4ff873e0a5 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -36,6 +36,7 @@ class BuildConfig(object): defaults = { "opt_level": 2, "add_pass": None, + "fallback_device": None, } def __init__(self, **kwargs): @@ -96,6 +97,10 @@ def build_config(**kwargs): add_pass: set of str Optimization pass to be added regardless of optimization level. + fallback_device : str or tvm.TVMContext + The fallback device. It is also used as the default device for + operators without specified device during heterogeneous execution. + Returns ------- config: BuildConfig @@ -192,8 +197,7 @@ def optimize(func, target, params=None): return func -def build(func, target=None, target_host=None, params=None, - fallback_device=None): +def build(func, target=None, target_host=None, params=None): """Build a function to run on TVM graph runtime. Parameters @@ -219,10 +223,6 @@ def build(func, target=None, target_host=None, params=None, Input parameters to the graph that do not change during inference time. Used for constant folding. - fallback_device : str or tvm.TVMContext, optional. - The fallback device. It is also used as the default device for - operators with no specified device. - Returns ------- graph_json : str @@ -239,8 +239,7 @@ def build(func, target=None, target_host=None, params=None, raise ValueError("Target is not set in env or passed as argument.") if isinstance(target, dict): - target, fallback_device = \ - _update_heterogeneous_inputs(target, fallback_device) + target, fallback_device = _update_heterogeneous_inputs(target) elif isinstance(target, (str, _target.Target)): target = _target.create(target) else: @@ -277,7 +276,7 @@ def build(func, target=None, target_host=None, params=None, return graph_json, mod, params -def _update_heterogeneous_inputs(target, fallback_device=None): +def _update_heterogeneous_inputs(target): """Update the target and fallback device required for heterogeneous compilation. CPU is used as the fallback device if it wasn't provided. Meanwhile, a CPU device type and "llvm" pair will be added to the target @@ -288,10 +287,6 @@ def _update_heterogeneous_inputs(target, fallback_device=None): target : dict of str(i.e. device/context name) to str/tvm.target.Target. A dict contains context to target pairs. - fallback_device : str or tvm.TVMContext, optional. - The fallback device. It is also used as the default device for - operators with no specified device. - Returns ------- device_target : dict of int to tvm.target.Target. @@ -305,6 +300,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None): "heterogeneous execution, but received %s." % type(target)) + fallback_device = BuildConfig.current.fallback_device if fallback_device is None: # cpu is used as the default fallback device when heterogeneous # execution is needed, but no fallback device is provided. @@ -315,7 +311,7 @@ def _update_heterogeneous_inputs(target, fallback_device=None): elif isinstance(fallback_device, TVMContext): fallback_device = fallback_device.device_type else: - raise ValueError("fallback_device expects the type of str or" + + raise ValueError("fallback_device expects the type of str or " + "TVMContext, but received %s." % type(fallback_device)) device_target = {} diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 1808ecb818a81..9f54a9fa949f4 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -3,7 +3,6 @@ import tvm from tvm import relay -from tvm.relay import testing from tvm.contrib import graph_runtime @@ -248,12 +247,14 @@ def get_func(): def test_runtime(target, device, func, fallback_device=None): params = {"x": x_data, "y": y_data} - with relay.build_config(opt_level=1): + config = {"opt_level": 1} + if fallback_device: + config["fallback_device"] = fallback_device + with relay.build_config(**config): graph, lib, params = relay.build( func, target, - params=params, - fallback_device=fallback_device) + params=params) contexts = [tvm.cpu(0), tvm.context(device)] mod = graph_runtime.create(graph, lib, contexts) mod.set_input(**params) @@ -367,13 +368,11 @@ def annotated(): test_runtime(target, device, annotated_func, fallback_device) def test_fallback_all_operators(device, tgt): - target = {"cpu": "llvm", device: tgt} - fallback_device = tvm.cpu(0) - + target = {device: tgt} annotated_func = get_func() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device) + test_runtime(target, device, annotated_func) for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"), ("opencl", str(tvm.target.intel_graphics()))]: From 237dbf230eceaf40b9d5ec0cde178a9358183fd4 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 17 Jan 2019 21:22:59 -0800 Subject: [PATCH 09/16] [TUTORIAL] Introduce frontend folder (#2457) --- docs/conf.py | 1 + tutorials/frontend/README.txt | 4 ++++ tutorials/{relay => frontend}/from_onnx.py | 0 tutorials/nnvm/README.txt | 4 ++-- 4 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 tutorials/frontend/README.txt rename tutorials/{relay => frontend}/from_onnx.py (100%) diff --git a/docs/conf.py b/docs/conf.py index 7170038247031..1ac6c1d2b01fa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -190,6 +190,7 @@ def run_doxygen(folder): subsection_order = ExplicitOrder( ['../tutorials/language', + '../tutorials/frontend', '../tutorials/optimize', '../tutorials/autotvm', '../tutorials/dev', diff --git a/tutorials/frontend/README.txt b/tutorials/frontend/README.txt new file mode 100644 index 0000000000000..319506d21f8f7 --- /dev/null +++ b/tutorials/frontend/README.txt @@ -0,0 +1,4 @@ +.. _tutorial-frontend: + +Compile Deep Learning Models +---------------------------- diff --git a/tutorials/relay/from_onnx.py b/tutorials/frontend/from_onnx.py similarity index 100% rename from tutorials/relay/from_onnx.py rename to tutorials/frontend/from_onnx.py diff --git a/tutorials/nnvm/README.txt b/tutorials/nnvm/README.txt index 772953ce96ac2..334409cd8a288 100644 --- a/tutorials/nnvm/README.txt +++ b/tutorials/nnvm/README.txt @@ -1,4 +1,4 @@ .. _tutorial-nnvm: -Compile Deep Learning Models ----------------------------- +NNVM Compiler Tutorials +----------------------- From f7d0883d43be263abdf6b9318a933b6b807519d5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 18 Jan 2019 14:14:19 -0800 Subject: [PATCH 10/16] [DOCS][COMMUNITY] Improve code review guideline on API designs (#2459) --- docs/contribute/code_review.rst | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/docs/contribute/code_review.rst b/docs/contribute/code_review.rst index 3442969327036..77ba4cb659550 100644 --- a/docs/contribute/code_review.rst +++ b/docs/contribute/code_review.rst @@ -8,10 +8,37 @@ Open source code is maintained by a community with diverse backend, and it is ev Here are some checklists for code reviews, it is also helpful reference for contributors + Hold the Highest Standard ------------------------- The first rule for code reviewers is to always keep the highest standard, and do not approve code just to "be friendly". Good, informative critics each other learn and prevents technical debt in early stages. +Deliberate on API and Data Structures +------------------------------------- +A minimum and stable API is critical to the project’s life. A good API makes a huge difference. Always think very carefully about all the aspects including naming, argument definitions and behavior. + +When possible, pay more time and thoughts into the API design during code reviews. +Remember, it is easier to improve code implementation, but it is extremely hard to change an API. +We should do the same for data structures that are shared across modules(e.g. AST). +When uncertain, start a conversation with more developers. + +Here are some useful principles for designing APIs: + +- Be consistent with existing well-known package’s APIs if the feature overlap. + For example, tensor operation APIs should always be consistent with the numpy API. +- Be consistent with existing APIs in the same project. + For example, we should use the same argument ordering across all the optimization passes, + so there is no "surprise" when using them. +- Think about whether the API will change in the future. + For example, we will have more options like loop_unrolling and device placement policy + as we add more optimizations in build. We can package optimization knobs into a build + configuration object. So that the build API is stable over time. +- Write down documents. Documents are mandatory for APIs and sometimes writing documents helps + us to think about whether we need clarification. +- Minimum. Think about how many lines of code a user has to write to use the API. + Remove layers of abstraction when possible. + + Ensure Test Coverage -------------------- Each new change of features should introduce test cases, bug fixes should include regression tests that prevent the problem from happening again. @@ -20,10 +47,6 @@ Documentations are Mandatory ---------------------------- Documentation is usually a place we overlooked, new functions or change to a function should be directly updated in documents. A new feature is meaningless without documentation to make it accessible. See more at :ref:`doc_guide` -Deliberate on User-facing API ------------------------------ -A good, minimum and stable API is critical to the project’s life. A good API makes a huge difference. Always think very carefully about all the aspects including naming, arguments definitions and behavior. One good rule to check is to be consistent with existing well-known package’s APIs if the feature overlap. For example, tensor operation APIs should always be consistent with the numpy. - Minimum Dependency ------------------ Always be cautious in introducing dependencies. While it is important to reuse code and not reinventing the wheel, dependencies can increase burden of users in deployment. A good design principle only depends on the part when a user actually use it. From 02455306858a1161089f21536810d20726a0bce0 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 18 Jan 2019 19:24:47 -0800 Subject: [PATCH 11/16] [COMMUNITY] @junrushao1994 -> Reviewer (#2463) --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 23d22686705b4..4370c7405cec8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -49,6 +49,7 @@ We do encourage everyone to work anything they are interested in. - [Jared Roesch](https://github.com/jroesch): @jroesch - [Siva](https://github.com/srkreddy1238): @srkreddy1238 - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel +- [Junru Shao](https://github.com/junrushao1994): @junrushao1994 - [Haichen Shen](https://github.com/icemelon9): @icemelon9 - [Alex Weaver](https://github.com/alex-weaver): @alex-weaver - [Yao Wang](https://github.com/kevinthesun): @kevinthesun From 5194da65bf4266818305c250e171ede5e7279a44 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Sat, 19 Jan 2019 10:02:57 -0800 Subject: [PATCH 12/16] Add gluoncv installation (#2464) --- docker/Dockerfile.ci_gpu | 3 +++ docker/Dockerfile.demo_cpu | 2 +- docker/Dockerfile.demo_gpu | 2 +- docker/install/ubuntu_install_gluoncv.sh | 1 + 4 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 docker/install/ubuntu_install_gluoncv.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index a8bbc60bb3ae8..fa15113289d09 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -46,6 +46,9 @@ RUN bash /install/ubuntu_install_opengl.sh COPY install/ubuntu_install_mxnet.sh /install/ubuntu_install_mxnet.sh RUN bash /install/ubuntu_install_mxnet.sh +COPY install/ubuntu_install_gluoncv.sh /install/ubuntu_install_gluoncv.sh +RUN bash /install/ubuntu_install_gluoncv.sh + COPY install/ubuntu_install_coreml.sh /install/ubuntu_install_coreml.sh RUN bash /install/ubuntu_install_coreml.sh diff --git a/docker/Dockerfile.demo_cpu b/docker/Dockerfile.demo_cpu index 0778b0a28784a..e8a20dec8fc17 100644 --- a/docker/Dockerfile.demo_cpu +++ b/docker/Dockerfile.demo_cpu @@ -21,7 +21,7 @@ RUN echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main \ RUN pip3 install matplotlib Image Pillow jupyter[notebook] # Deep learning frameworks -RUN pip3 install mxnet tensorflow keras +RUN pip3 install mxnet tensorflow keras gluoncv # Build TVM COPY install/install_tvm_cpu.sh /install/install_tvm_cpu.sh diff --git a/docker/Dockerfile.demo_gpu b/docker/Dockerfile.demo_gpu index d20293c4ed3df..6fe6b2ae8b09d 100644 --- a/docker/Dockerfile.demo_gpu +++ b/docker/Dockerfile.demo_gpu @@ -21,7 +21,7 @@ RUN echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main \ RUN pip3 install matplotlib Image Pillow jupyter[notebook] # Deep learning frameworks -RUN pip3 install mxnet tensorflow keras +RUN pip3 install mxnet tensorflow keras gluoncv # Build TVM COPY install/install_tvm_gpu.sh /install/install_tvm_gpu.sh diff --git a/docker/install/ubuntu_install_gluoncv.sh b/docker/install/ubuntu_install_gluoncv.sh new file mode 100644 index 0000000000000..0ca1a34cbc242 --- /dev/null +++ b/docker/install/ubuntu_install_gluoncv.sh @@ -0,0 +1 @@ +pip3 install gluoncv From eca4f88a74825be7e90f962a813717a32192c826 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 19 Jan 2019 10:05:25 -0800 Subject: [PATCH 13/16] Fix broadcast add and subtract grad (#2465) --- CMakeLists.txt | 2 +- python/tvm/relay/op/_tensor.py | 14 ++++++--- tests/python/relay/test_ad.py | 56 ++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 363b2056a87ad..8765a3346069e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,7 +83,7 @@ else(MSVC) include(CheckCXXCompilerFlag) check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - add_compile_options(-Wall -fPIC -std=c++11) + add_compile_options(-O0 -Wall -fPIC -std=c++11) else() set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}") diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 240e7fffd3289..d9b5e2e89ce05 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -5,17 +5,21 @@ from .op import register_compute, register_schedule, register_pattern from .op import register_gradient from .op import schedule_injective, OpPattern +from .transform import collapse_sum_like +from .tensor import negative + def add_grad(orig, grad): - from tvm.relay import op - return [op.broadcast_to_like(grad, orig.args[0]), op.broadcast_to_like(grad, orig.args[1])] + return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])] + register_gradient("add", add_grad) + def subtract_grad(orig, grad): - from tvm.relay import op - return [op.broadcast_to_like(grad, orig.args[0]), - op.broadcast_to_like(op.negative(grad), orig.args[1])] + return [collapse_sum_like(grad, orig.args[0]), + collapse_sum_like(negative(grad), orig.args[1])] + register_gradient("subtract", subtract_grad) diff --git a/tests/python/relay/test_ad.py b/tests/python/relay/test_ad.py index 7844236907c41..6b5d0e7769343 100644 --- a/tests/python/relay/test_ad.py +++ b/tests/python/relay/test_ad.py @@ -69,8 +69,64 @@ def test_sub(): np.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy())) +def test_broadcast_add(): + shape1 = (3, 4, 1) + shape2 = (1, 5) + dtype = 'float32' + x_nd = rand(dtype, *shape1) + y_nd = rand(dtype, *shape2) + x_np = x_nd.asnumpy() + y_np = y_nd.asnumpy() + expected_forward = x_np + y_np + t1 = relay.TensorType(shape1, dtype) + t2 = relay.TensorType(shape2, dtype) + x = relay.var("x", t1) + y = relay.var("y", t2) + func = relay.Function([x, y], x + y) + full_func = relay.ir_pass.infer_type(gradient(func)) + assert full_func.checked_type == relay.FuncType([t1, t2], + relay.TupleType([relay.TensorType(expected_forward.shape, dtype), + relay.TupleType([t1, t2])])) + ex = create_executor() + forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd) + np.testing.assert_allclose(forward.asnumpy(), expected_forward) + np.testing.assert_allclose(grad_x.asnumpy(), + np.ones_like(expected_forward).sum(axis=2, keepdims=True)) + np.testing.assert_allclose(grad_y.asnumpy(), + np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0)) + + +def test_broadcast_subtract(): + shape1 = (3, 4, 1) + shape2 = (1, 5) + dtype = 'float32' + x_nd = rand(dtype, *shape1) + y_nd = rand(dtype, *shape2) + x_np = x_nd.asnumpy() + y_np = y_nd.asnumpy() + expected_forward = x_np - y_np + t1 = relay.TensorType(shape1, dtype) + t2 = relay.TensorType(shape2, dtype) + x = relay.var("x", t1) + y = relay.var("y", t2) + func = relay.Function([x, y], x - y) + full_func = relay.ir_pass.infer_type(gradient(func)) + assert full_func.checked_type == relay.FuncType([t1, t2], + relay.TupleType([relay.TensorType(expected_forward.shape, dtype), + relay.TupleType([t1, t2])])) + ex = create_executor() + forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd) + np.testing.assert_allclose(forward.asnumpy(), expected_forward) + np.testing.assert_allclose(grad_x.asnumpy(), + np.ones_like(expected_forward).sum(axis=2, keepdims=True)) + np.testing.assert_allclose(grad_y.asnumpy(), + -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0)) + + if __name__ == "__main__": test_id() test_add() test_temp_add() test_sub() + test_broadcast_add() + test_broadcast_subtract() From 45456b1495048883fa3b3f2095c6747b8188db85 Mon Sep 17 00:00:00 2001 From: Kuan Hsh Chen Date: Sun, 20 Jan 2019 03:41:30 +0800 Subject: [PATCH 14/16] Fix typo (#2467) --- nnvm/python/nnvm/symbol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index 6997ecc64654f..0acacb247a2c9 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -36,7 +36,7 @@ class Symbol(SymbolBase): - """Symbol is basic operation unit for symbolic graph compostion.""" + """Symbol is basic operation unit for symbolic graph composition.""" # disable dictionary storage, also do not have parent type. __slots__ = [] From e4b9f986dab8c48ba109a52106565fc4be6b67c4 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 20 Jan 2019 11:01:28 -0800 Subject: [PATCH 15/16] Frontend before tensor expression --- docs/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1ac6c1d2b01fa..1166d73e9264c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -189,8 +189,8 @@ def run_doxygen(folder): gallery_dirs = ["tutorials", "vta/tutorials"] subsection_order = ExplicitOrder( - ['../tutorials/language', - '../tutorials/frontend', + ['../tutorials/frontend', + '../tutorials/language', '../tutorials/optimize', '../tutorials/autotvm', '../tutorials/dev', From 0806b69e3fb136226fa1dafad00bd2c606cc998d Mon Sep 17 00:00:00 2001 From: llyfacebook <34827865+llyfacebook@users.noreply.github.com> Date: Sun, 20 Jan 2019 18:08:03 -0800 Subject: [PATCH 16/16] [RPC] Add the IPV6 support for server side auto tuning (#2462) * use IPV6 instead of IPV4 * backward compatible * add error report * fix linter * more linter * fix the python2 api --- python/tvm/exec/rpc_server.py | 2 +- python/tvm/rpc/base.py | 7 +++- python/tvm/rpc/proxy.py | 5 +-- python/tvm/rpc/server.py | 4 +-- python/tvm/rpc/tracker.py | 4 +-- src/common/socket.h | 53 ++++++++++++++++++++++++------ src/runtime/rpc/rpc_socket_impl.cc | 2 +- 7 files changed, 58 insertions(+), 19 deletions(-) diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 5998e9ffe6ac0..73c943366b4ce 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -12,7 +12,7 @@ def main(args): """Main function""" if args.tracker: - url, port = args.tracker.split(":") + url, port = args.tracker.rsplit(":", 1) port = int(port) tracker_addr = (url, port) if not args.key: diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index 5731eb870a9d0..294b5c2e4060d 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -42,6 +42,11 @@ class TrackerCode(object): RPC_SESS_MASK = 128 +def get_addr_family(addr): + res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP) + return res[0][0] + + def recvall(sock, nbytes): """Receive all nbytes from socket. @@ -142,7 +147,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5): tstart = time.time() while True: try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket.socket(get_addr_family(addr), socket.SOCK_STREAM) sock.connect(addr) return sock except socket.error as sock_err: diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index ad9f189f4a78e..cefffbfa9668e 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -298,7 +298,8 @@ def _update_tracker(self, period_update=False): """Update information on tracker.""" try: if self._tracker_conn is None: - self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._tracker_conn = socket.socket(base.get_addr_family(self._tracker_addr), + socket.SOCK_STREAM) self._tracker_conn.connect(self._tracker_addr) self._tracker_conn.sendall(struct.pack("ai_family == AF_INET) - << "Does not support IPv6"; - memcpy(&addr, res->ai_addr, res->ai_addrlen); - addr.sin_port = htons(port); + switch (res->ai_family) { + case AF_INET: { + sockaddr_in *addr4 = reinterpret_cast(&addr); + memcpy(addr4, res->ai_addr, res->ai_addrlen); + addr4->sin_port = htons(port); + addr4->sin_family = AF_INET; + } + break; + case AF_INET6: { + sockaddr_in6 *addr6 = reinterpret_cast(&addr); + memcpy(addr6, res->ai_addr, res->ai_addrlen); + addr6->sin6_port = htons(port); + addr6->sin6_family = AF_INET6; + } + break; + default: + CHECK(false) << "cannot decode address"; + } freeaddrinfo(res); } /*! \brief return port of the address */ int port() const { - return ntohs(addr.sin_port); + return ntohs((addr.ss_family == AF_INET6)? \ + reinterpret_cast(&addr)->sin6_port : \ + reinterpret_cast(&addr)->sin_port); + } + /*! \brief return the ip address family */ + int ss_family() const { + return addr.ss_family; } /*! \return a string representation of the address */ std::string AsString() const { std::string buf; buf.resize(256); + + const void *sinx_addr = nullptr; + if (addr.ss_family == AF_INET6) { + const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; + sinx_addr = reinterpret_cast(&addr6); + } else if (addr.ss_family == AF_INET) { + const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; + sinx_addr = reinterpret_cast(&addr4); + } else { + CHECK(false) << "illegal address"; + } + #ifdef _WIN32 - const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, + const char *s = inet_ntop(addr.ss_family, sinx_addr, &buf[0], buf.length()); #else - const char *s = inet_ntop(AF_INET, &addr.sin_addr, + const char *s = inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast(buf.length())); #endif CHECK(s != nullptr) << "cannot decode address"; @@ -294,7 +327,7 @@ class TCPSocket : public Socket { * \param af domain */ void Create(int af = PF_INET) { - sockfd = socket(PF_INET, SOCK_STREAM, 0); + sockfd = socket(af, SOCK_STREAM, 0); if (sockfd == INVALID_SOCKET) { Socket::Error("Create"); } diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 6b2fa6c1f6083..bf8bce9d0f7db 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -43,7 +43,7 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string key) { common::TCPSocket sock; common::SockAddr addr(url.c_str(), port); - sock.Create(); + sock.Create(addr.ss_family()); CHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed"; // hand shake