From 6756b9cabbee6774dbd379130ee0f956b8c6b87b Mon Sep 17 00:00:00 2001
From: Marisa Kirisame
Date: Thu, 21 Mar 2019 15:14:15 -0700
Subject: [PATCH] [Relay] Handle free variables and global variables in the
 partial evaluator

The partial evaluator now handles free variables, which makes preserving
functions trivial, and it can partially evaluate global functions through
the Module, updating each GlobalVar binding in place. Supporting changes:

* Move VarHash/VarEqual from partial_eval.cc into include/tvm/relay/expr.h
  so any pass can key containers by variable identity (vid) rather than by
  VarNode pointer.
* Add an Annotate expression that wraps an Expr with pass-private metadata,
  dispatched through ExprFunctor/ExprVisitor/ExprMutator. The partial
  evaluator uses it to tag functions with a FuncId (WithFuncId).
* Stamp every PStatic with a creation time so recursion can be detected.
* Make TakeRel tolerate IncompleteType inputs instead of CHECK-failing.
* Mark nn.batch_matmul as OPAQUE, print CallNode more compactly, and CHECK
  invariants in TypeCallNode::make and BatchMatmulRel.
* Add tests (including regression tests for invalidate and lookup) in
  tests/python/relay/test_pass_partial_eval.py.

Co-Authored-By: MarisaKirisame
---
 include/tvm/relay/expr.h                     |  41 +-
 include/tvm/relay/expr_functor.h             |   4 +
 python/tvm/relay/op/nn/_nn.py                |   2 +-
 src/relay/ir/expr.cc                         |  15 +-
 src/relay/ir/expr_functor.cc                 |   8 +
 src/relay/ir/type.cc                         |   1 +
 src/relay/op/nn/nn.cc                        |   2 +-
 src/relay/op/tensor/transform.cc             |  14 +-
 src/relay/pass/partial_eval.cc               | 392 +++++++++++++++----
 tests/python/relay/test_pass_partial_eval.py | 183 +++++----
 10 files changed, 518 insertions(+), 144 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 1d2fa5472993f..18ec944f54f52 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -184,6 +184,26 @@ class VarNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
 
+/*! \brief Hash Var by its id.
+ * Different VarNodes might have the same vid; they are considered the same var in that case.
+ * Use VarHash to hash Var by id.
+ */
+struct VarHash {
+  size_t operator()(const Var& v) const {
+    return v->vid.hash();
+  }
+};
+
+/*! \brief Compare Var by its id.
+ * Different VarNodes might have the same vid; they are considered the same var in that case.
+ * Use VarEqual to compare Var by id.
+ */
+struct VarEqual {
+  bool operator()(const Var& l, const Var& r) const {
+    return l->vid.get() == r->vid.get();
+  }
+};
+
 /*!
  * \brief Global variable that lives in the top-level module.
  * This is used to enable recursive calls between functions.
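VarHash and VarEqual exist so passes can key hash containers by variable identity (the shared vid) rather than by VarNode pointer identity. A minimal usage sketch against the API added above; the map, variable name, and payload type are illustrative, not part of the patch:

    #include <unordered_map>
    #include <tvm/relay/expr.h>

    using namespace tvm::relay;

    void VarKeyedMapSketch() {
      // Two Var handles whose VarNodes share a vid hash and compare equal,
      // so they address the same slot in this map.
      std::unordered_map<Var, int, VarHash, VarEqual> use_count;
      Var x = VarNode::make("x", Type());
      use_count[x] = 1;
      ++use_count[x];  // still one entry, keyed by x's vid
    }

The partial evaluator's Environment (below) is the in-tree consumer of these functors; before this patch they were private to partial_eval.cc.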
@@ -521,7 +541,7 @@ RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
  * rewriting pass such as layout or type transformation.
  *
  * Subclass TempExprNode allows us to pattern match on
- * specific kind TempExpr and use them for expression rewriting.
+ * specific kind of TempExpr and use them for expression rewriting.
  *
  * TempExpr should only be used within a pass,
  */
@@ -539,6 +559,25 @@ class TempExprNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);
 
+class Annotate;
+class AnnotateNode : public ExprNode {
+ public:
+  Expr expr;
+  NodeRef annotation;
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("expr", &expr);
+    v->Visit("annotation", &annotation);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Annotate make(Expr expr, NodeRef annotation);
+
+  static constexpr const char* _type_key = "relay.AnnotateNode";
+  TVM_DECLARE_NODE_TYPE_INFO(AnnotateNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Annotate, AnnotateNode, Expr);
+
 // implementations
 inline const Type& ExprNode::checked_type() const {
   CHECK(checked_type_.defined()) << "internal error: the type checker has "
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 3b179f8e53300..d3154d28bb272 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -116,6 +116,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
   virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AnnotateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args...) {
     throw Error(std::string("Do not have a default for ") + op->type_key());
   }
@@ -140,6 +141,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(AnnotateNode);
     return vtable;
   }
 };
@@ -170,6 +172,7 @@ class ExprVisitor
   void VisitExpr_(const RefWriteNode* op) override;
   void VisitExpr_(const ConstructorNode* op) override;
   void VisitExpr_(const MatchNode* op) override;
+  void VisitExpr_(const AnnotateNode* op) override;
   virtual void VisitType(const Type& t);
   virtual void VisitClause(const Clause& c);
   virtual void VisitPattern(const Pattern& c);
@@ -212,6 +215,7 @@ class ExprMutator
   Expr VisitExpr_(const RefWriteNode* op) override;
   Expr VisitExpr_(const ConstructorNode* op) override;
   Expr VisitExpr_(const MatchNode* op) override;
+  Expr VisitExpr_(const AnnotateNode* op) override;
 
   /*!
    * \brief Used to visit the types inside of expressions.
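With the AnnotateNode hooks in place, any ExprMutator can intercept annotations. A minimal sketch of a pass that erases them, assuming only the API shown above; the class name is illustrative (the patch's own eraser, StripWithFuncId, lives in partial_eval.cc):

    #include <tvm/relay/expr.h>
    #include <tvm/relay/expr_functor.h>

    namespace tvm {
    namespace relay {

    // Drops every Annotate wrapper while rewriting the tree underneath it.
    // The default ExprMutator::VisitExpr_(const AnnotateNode*) instead
    // rebuilds the wrapper, so annotations survive passes that do not opt
    // in like this.
    class AnnotateEraser : public ExprMutator {
     public:
      Expr VisitExpr_(const AnnotateNode* op) final {
        return VisitExpr(op->expr);
      }
    };

    }  // namespace relay
    }  // namespace tvm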
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 5a47b1d42ed31..f1e3942743dbb 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -74,7 +74,7 @@ def schedule_batch_matmul(attrs, outputs, target):
     with target:
         return topi.generic.schedule_batch_matmul(outputs)
 
-reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+reg.register_pattern("nn.batch_matmul", reg.OpPattern.OPAQUE)
 
 # conv2d
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 3108bc2501fed..422163758a2f2 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -232,8 +232,7 @@ TVM_REGISTER_API("relay._make.Call")
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
-    p->stream << "CallNode(" << node->op << ", " << node->args << ", "
-              << node->attrs << ", " << node->type_args << ")";
+    p->stream << "CallNode(" << node->op << ")";
   });
 
 Let LetNode::make(Var var, Expr value, Expr body) {
@@ -349,5 +348,17 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
     *ret = temp->Realize();
   });
 
+Annotate AnnotateNode::make(Expr expr, NodeRef annotation) {
+  NodePtr<AnnotateNode> n = make_node<AnnotateNode>();
+  n->expr = std::move(expr);
+  n->annotation = std::move(annotation);
+  return Annotate(n);
+}
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<AnnotateNode>([](const AnnotateNode* node, tvm::IRPrinter* p) {
+    p->stream << "AnnotateNode(" << node->expr << ")";
+  });
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index d0cd30adda29f..aaaf34d261a17 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -221,6 +221,10 @@ Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
 
 Type ExprMutator::VisitType(const Type& t) { return t; }
 
+Expr ExprMutator::VisitExpr_(const AnnotateNode* op) {
+  return AnnotateNode::make(VisitExpr(op->expr), op->annotation);
+}
+
 void ExprVisitor::VisitExpr(const Expr& expr) {
   auto it = visit_counter_.find(expr.get());
   if (it != visit_counter_.end()) {
@@ -315,6 +319,10 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) {
   }
 }
 
+void ExprVisitor::VisitExpr_(const AnnotateNode* op) {
+  this->VisitExpr(op->expr);
+}
+
 void ExprVisitor::VisitClause(const Clause& op) {
   this->VisitPattern(op->lhs);
   this->VisitExpr(op->rhs);
diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc
index fb0d919b46c38..718bad63693b1 100644
--- a/src/relay/ir/type.cc
+++ b/src/relay/ir/type.cc
@@ -113,6 +113,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
   });
 
 TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
+  CHECK(func.as<GlobalTypeVarNode>());
   NodePtr<TypeCallNode> n = make_node<TypeCallNode>();
   n->func = std::move(func);
   n->args = std::move(args);
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index d24431347f808..c2f24e2179bef 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -683,7 +683,7 @@ bool BatchMatmulRel(const Array<Type>& types,
   const auto* x = types[0].as<TensorTypeNode>();
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
-  if (x->shape.size() != 3 || y->shape.size() != 3) return false;
+  CHECK(x->shape.size() == 3 && y->shape.size() == 3);
   CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
       << "BatchDot: batch dimension doesn't match, "
       << " x shape=" << x->shape
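Tying the expr.cc pieces together: constructing an annotated expression and relying on the registered printer. A sketch under assumed names (WrapWithNote and its arguments are hypothetical, not part of the patch):

    #include <tvm/relay/expr.h>

    using namespace tvm;
    using namespace tvm::relay;

    // Wraps an expression with pass-private metadata. Printing the result
    // goes through the IRPrinter dispatch added in expr.cc above, which
    // shows only the wrapped expression: "AnnotateNode(<expr>)".
    Expr WrapWithNote(const Expr& body, const NodeRef& note) {
      return AnnotateNode::make(body, note);
    }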
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index f86156bdbddcd..0f83f2cf194f2 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -729,9 +729,19 @@ bool TakeRel(const Array<Type>& types,
   // `types` contains: [data, indices, result]
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
-  CHECK(data != nullptr);
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "must be tensor type or incomplete type";
+    return false;
+  }
+
   const auto* indices = types[1].as<TensorTypeNode>();
-  CHECK(indices != nullptr);
+  if (indices == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "must be tensor type or incomplete type";
+    return true;
+  }
+
   const auto param = attrs.as<TakeAttrs>();
   CHECK(param != nullptr);
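The TakeRel change follows the usual type-relation contract: a relation returns false while its inputs are still IncompleteType so the solver retries later, and CHECK-fails only on genuinely malformed input. A minimal sketch of that contract with a hypothetical relation (MyIdentityRel is not part of the patch):

    // Hypothetical relation over [input, output]: output has the same type
    // as input. Returning false means "not enough information yet, run me
    // again once the solver has resolved more types".
    bool MyIdentityRel(const Array<Type>& types,
                       int num_inputs,
                       const Attrs& attrs,
                       const TypeReporter& reporter) {
      CHECK_EQ(types.size(), 2);
      const auto* in = types[0].as<TensorTypeNode>();
      if (in == nullptr) {
        CHECK(types[0].as<IncompleteTypeNode>())
            << "must be tensor type or incomplete type";
        return false;  // defer until the input is resolved
      }
      reporter->Assign(types[1], TensorTypeNode::make(in->shape, in->dtype));
      return true;
    }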
diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc
index f6283d380176a..63ad11d2e1182 100644
--- a/src/relay/pass/partial_eval.cc
+++ b/src/relay/pass/partial_eval.cc
@@ -1,22 +1,3 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
 /*!
  * Copyright (c) 2018 by Contributors
  *
@@ -104,6 +85,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
+#include "../ir/type_functor.h"
 #include "pass_util.h"
 #include "let_list.h"
@@ -112,26 +94,7 @@ namespace relay {
 
 using namespace runtime;
 
-/*! \brief Hash Var by it's id.
- * Different VarNode might has same vid, and they are considered to be the same var in such case.
- * Use VarHash to hash Var by id.
- */
-struct VarHash {
-  size_t operator()(const Var& v) const {
-    return v->vid.hash();
-  }
-};
-
-/*! \brief Compare Var by it's id.
- * Different VarNode might has same vid, and they are considered to be the same var in such case.
- * Use VarEqual to compare Var by id.
- */
-struct VarEqual {
-  bool operator()(const Var& l, const Var& r) const {
-    return l->vid.get() == r->vid.get();
-  }
-};
-
+Expr PostProcess(const Expr&);
 
 /*! \brief The base container type of Relay values.
  */
class StaticNode : public RelayNode {
 public:
@@ -150,10 +113,20 @@ class Static : public NodeRef {
   using ContainerType = StaticNode;
 };
 
+using Time = size_t;
+
 struct PStaticNode : Node {
+  static Time time() {
+    static Time time_ = 0;
+    Time ret = time_;
+    time_++;
+    return ret;
+  }
   Static pstatic;  // may be null
   Expr dynamic;
-  PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
+  Time created_time;
+  PStaticNode(const Static& pstatic, const Expr& dynamic) :
+    pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
   explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
   TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
 };
@@ -261,7 +234,7 @@ class Environment {
       }
       ++rit;
     }
-    LOG(FATAL) << "Unknown Variable: " << v;
+    LOG(FATAL) << "Unknown Variable: " << v << v.as<VarNode>();
     throw;
   }
@@ -341,6 +314,7 @@ class Store {
 };
 
 PStatic HasStatic(const Static& stat, const Expr& dynamic) {
+  CHECK(stat.defined());
   return PStatic(make_node<PStaticNode>(stat, dynamic));
 }
@@ -383,15 +357,61 @@ FInterpreter CPUInterpreter() {
   return CreateInterpreter(Module(nullptr), CPUContext(), target);
 }
 
-class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
+bool IsAtomic(const Expr& e) {
+  return e.as<VarNode>() || e.as<OpNode>() || e.as<FunctionNode>() || e.as<GlobalVarNode>();
+}
+
+using FuncId = size_t;
+
+struct WithFuncId;
+
+struct WithFuncIdNode : Node {
+  FuncId fid;
+  WithFuncIdNode(FuncId fid) : fid(fid) { }
+  static constexpr const char* _type_key = "relay.WithFuncId";
+  TVM_DECLARE_NODE_TYPE_INFO(WithFuncIdNode, Node);
+};
+
+RELAY_DEFINE_NODE_REF(WithFuncId, WithFuncIdNode, NodeRef);
+
+Annotate MkWithFuncId(const Expr& expr, FuncId fid) {
+  return AnnotateNode::make(expr, WithFuncId(make_node<WithFuncIdNode>(fid)));
+}
+
+Expr StripWithFuncId(const Expr& e);
+
+Expr DeDup(const Expr& e);
+
+Function AsFunc(const Expr& e) {
+  if (e.as<FunctionNode>()) {
+    return Downcast<Function>(e);
+  } else if (const AnnotateNode* a = e.as<AnnotateNode>()) {
+    CHECK(a->annotation.as<WithFuncIdNode>());
+    return AsFunc(a->expr);
+  } else {
+    LOG(FATAL) << "Unknown case";
+    throw;
+  }
+}
+
+class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
                          public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
  public:
-  PartialEvaluator(const tvm::Array<Var>& free_vars) {
+  PartialEvaluator(const tvm::Array<Var>& free_vars,
+                   const Module& mod) :
+    mod_(mod) {
     for (const Var& v : free_vars) {
      env_.Insert(v, NoStatic(v));
    }
  }
 
+  size_t depth = 0;
+
+  PStatic VisitExpr(const Expr& e, LetList* ll) final {
+    PStatic ret = ExprFunctor<PStatic(const Expr& e, LetList* ll)>::VisitExpr(e, ll);
+    CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
+    return ret;
+  }
+
   PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
     return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
   }
@@ -421,7 +441,20 @@ }
 
   PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
-    return NoStatic(GetRef<GlobalVar>(op));
+    GlobalVar gv = GetRef<GlobalVar>(op);
+    if (gv_map_.count(gv) == 0) {
+      if (mod_.defined()) {
+        Function func = mod_->Lookup(gv);
+        InitializeFuncId(func);
+        Func f = VisitFuncStatic(func, gv);
+        gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+        func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
+        mod_->Update(gv, func);
+      } else {
+        gv_map_.insert({gv, NoStatic(gv)});
+      }
+    }
+    return gv_map_.at(gv);
   }
 
   PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
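The created_time stamp on PStatic is a value drawn from a process-wide monotonic counter; comparing stamps later tells the evaluator whether a value predates a given recursive call. A standalone toy of the mechanism (names are illustrative, not part of the patch):

    #include <cstddef>

    using Time = std::size_t;

    // One fresh, strictly increasing stamp per call, exactly like
    // PStaticNode::time() above.
    Time NextTime() {
      static Time time_ = 0;
      return time_++;
    }

    struct Value {
      Time created_time = NextTime();
    };

    // A value created before a recursion boundary has a smaller stamp
    // than anything created inside it.
    bool CreatedBefore(const Value& a, Time boundary) {
      return a.created_time < boundary;
    }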
@@ -501,19 +534,45 @@
     }
   }
 
-  PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
-    Function func = GetRef<Function>(op);
+  PStatic VisitExpr_(const AnnotateNode* op, LetList* ll) final {
+    CHECK(op->annotation.as<WithFuncIdNode>());
+    return VisitExpr(op->expr, ll);
+  }
+
+  struct TimeFrame {
+    PartialEvaluator* pe_;
+    FuncId fid_;
+    std::vector